use crate::allocator::{Allocator, AllocatorStats, AllocatorStatsDelta};
use crate::element::TensorElement;
use crate::environment::{EnvInner, Environment};
use crate::initializer::OwnedInitializer;
use crate::io_binding::{IoBinding, OutputValue};
use crate::memory::MemoryInfo;
use crate::prepacked::{PrepackedWeightsContainer, PrepackedWeightsInner};
use crate::run_options::RunOptions;
use crate::session_options::SessionOptions;
use crate::tensor::{AllocatedTensor, OwnedValue, RunInput, TensorBuffer};
use crate::{Error, Result, api, check, sys};
use futures_util::task::AtomicWaker;
use std::cell::UnsafeCell;
use std::ffi::{CStr, CString, c_char, c_void};
use std::marker::PhantomData;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Capacity of the fixed-size handle arrays in `Session::run_impl`; runs with more
/// inputs or outputs than this use heap-allocated handle vectors instead.
const STACK_IO_HANDLES: usize = 8;
/// Buffer size in bytes (1 MiB) at or above which `LaneBufferPolicy::Auto` resolves to
/// `AlignedPrefaulted`.
const AUTO_ALIGNED_BUFFER_THRESHOLD_BYTES: usize = 1 << 20;
/// Alignment value (4096) used when `Auto` resolves to `AlignedPrefaulted`.
const AUTO_ALIGNED_BUFFER_ALIGNMENT: usize = 4096;
/// Buffer size in bytes (2 MiB) at or above which `Auto` resolves to `HugePagePrefaulted`.
const AUTO_HUGEPAGE_BUFFER_THRESHOLD_BYTES: usize = 2 << 20;
/// Alignment value (2 MiB) passed to the huge-page `TensorBuffer` constructors for the
/// `HugePage*` policies that carry no explicit alignment.
const HUGEPAGE_BUFFER_ALIGNMENT: usize = 2 << 20;
/// Selects which `TensorBuffer` constructor backs the buffers of a tensor I/O lane.
///
/// Each variant maps to one `TensorBuffer::zeros*` constructor in `lane_tensor_buffer`;
/// an `alignment` field is forwarded to that constructor unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaneBufferPolicy {
/// Uses `TensorBuffer::zeros`.
Vec,
/// Uses `TensorBuffer::zeros_prefaulted`.
Prefaulted,
/// Uses `TensorBuffer::zeros_aligned` with the given alignment.
Aligned { alignment: usize },
/// Uses `TensorBuffer::zeros_aligned_prefaulted` with the given alignment.
AlignedPrefaulted { alignment: usize },
/// Uses `TensorBuffer::zeros_aligned_hugepage` with `HUGEPAGE_BUFFER_ALIGNMENT`.
HugePage,
/// Uses `TensorBuffer::zeros_aligned_hugepage_prefaulted` with `HUGEPAGE_BUFFER_ALIGNMENT`.
HugePagePrefaulted,
/// Uses `TensorBuffer::zeros_aligned_hugepage_prefaulted` with the given alignment.
AlignedHugePagePrefaulted { alignment: usize },
/// Uses `TensorBuffer::zeros_aligned_mlocked` with the given alignment.
AlignedMlocked { alignment: usize },
/// Uses `TensorBuffer::zeros_aligned_mlocked_prefaulted` with the given alignment.
AlignedMlockedPrefaulted { alignment: usize },
/// Uses `TensorBuffer::zeros_aligned_hugepage_mlocked` with `HUGEPAGE_BUFFER_ALIGNMENT`.
HugePageMlocked,
/// Uses `TensorBuffer::zeros_aligned_hugepage_mlocked_prefaulted` with
/// `HUGEPAGE_BUFFER_ALIGNMENT`.
HugePageMlockedPrefaulted,
/// Uses `TensorBuffer::zeros_aligned_hugepage_mlocked_prefaulted` with the given alignment.
AlignedHugePageMlockedPrefaulted { alignment: usize },
/// Resolved per buffer from its byte size: `HugePagePrefaulted` at or above
/// `AUTO_HUGEPAGE_BUFFER_THRESHOLD_BYTES`, else `AlignedPrefaulted` at or above
/// `AUTO_ALIGNED_BUFFER_THRESHOLD_BYTES`, else `Vec`.
Auto,
}
impl Default for LaneBufferPolicy {
#[inline]
fn default() -> Self {
Self::Auto
}
}
/// Type and shape metadata cached for one model input or output.
struct CachedIo {
// Value kind; reported by `Session::input_meta` / `Session::output_meta`.
onnx_type: sys::OnnxType,
// Element type; reported by the same accessors.
elem_type: sys::ElementType,
// Cached element count. When `None`, `Session::stamp_outputs` asks each tensor output
// for its element count after a run (presumably because the shape is not static —
// TODO confirm against `collect_io_meta`).
count: Option<usize>,
// Dimensions returned by `Session::input_shape` / `Session::output_shape`.
dims: Vec<i64>,
// Per-dimension entries returned by `Session::input_symbolic_dims` /
// `Session::output_symbolic_dims`.
symbolic: Vec<Option<String>>,
}
/// An inference session together with its cached input/output names and metadata.
pub struct Session {
// Raw runtime session handle passed to every session-level API call.
sess: *mut sys::SessionHandle,
// Input names; exposed as `&str` by `input_name`.
input_names: Vec<CString>,
// C-string pointers handed to the runtime as the input-name array on each run
// (presumably pointing into `input_names` — verify against `collect_io_names`).
input_ptrs: Vec<*const c_char>,
// Cached metadata for each input.
input_meta: Vec<CachedIo>,
// Output names; exposed as `&str` by `output_name`.
output_names: Vec<CString>,
// C-string pointers handed to the runtime as the output-name array on each run
// (presumably pointing into `output_names` — verify against `collect_io_names`).
output_ptrs: Vec<*const c_char>,
// Cached metadata for each output.
output_meta: Vec<CachedIo>,
// Run options used by the entry points that do not take explicit options.
run_opts: RunOptions,
// Stored at construction and not read in this part of the file; presumably held so
// the initializers stay alive while the session exists.
_owned_initializers: Vec<OwnedInitializer>,
// Shared prepacked-weights container, when the session was created with one.
_prepacked_weights: Option<Arc<PrepackedWeightsInner>>,
// Shared handle to the environment the session was created in.
_env: Arc<EnvInner>,
}
/// A run whose input handles were captured once so it can be executed repeatedly.
///
/// Created by `Session::prepare_run`.
pub struct PreparedRun<'s, 'i> {
// Session the run executes on.
session: &'s Session,
// Raw value pointers obtained from the inputs via `as_value_ptr` at preparation time.
input_handles: Vec<*const sys::ValueHandle>,
// Scratch array the runtime fills with output handles; reset to null before each run.
output_handles: Vec<*mut sys::ValueHandle>,
// Outputs of the most recent run; every slot is reset to `None` before each run.
outputs: Vec<Option<OwnedValue>>,
// Ties the struct to the lifetime of the inputs whose raw pointers are stored above.
_inputs: PhantomData<&'i dyn RunInput>,
}
/// An `IoBinding` with all inputs and outputs already bound, ready to be run repeatedly.
///
/// Created by `Session::prepare_io_binding` and `Session::prepare_io_binding_buffers`.
pub struct PreparedIoBinding<'s, 'v> {
// Session the binding runs on.
session: &'s Session,
// Binding holding the bound inputs and outputs.
binding: IoBinding,
// Lifetime marker for the values that were bound by reference.
_values: PhantomData<&'v ()>,
}
/// A reusable run lane that owns its input and output `TensorBuffer`s and the binding
/// that refers to them.
pub struct TensorIoLane<'s, T: TensorElement> {
// Session the lane runs on.
session: &'s Session,
// Declared before the buffers, so it is dropped before them (fields drop in
// declaration order).
binding: IoBinding,
// One buffer per model input, bound in input order.
inputs: Vec<TensorBuffer<T>>,
// One buffer per model output, bound in output order.
outputs: Vec<TensorBuffer<T>>,
}
/// A reusable run lane with `TensorBuffer` inputs and `AllocatedTensor` outputs.
pub struct AllocatedOutputTensorIoLane<'s, T: TensorElement> {
// Session the lane runs on.
session: &'s Session,
// Declared before the tensors, so it is dropped before them (fields drop in
// declaration order).
binding: IoBinding,
// One buffer per model input, bound in input order.
inputs: Vec<TensorBuffer<T>>,
// One tensor per model output, created with `AllocatedTensor::for_session`.
outputs: Vec<AllocatedTensor<T>>,
}
/// A reusable run lane with `TensorBuffer` inputs whose outputs are bound to a memory
/// location rather than to pre-created buffers.
pub struct DeviceOutputTensorIoLane<'s, T: TensorElement> {
// Session the lane runs on.
session: &'s Session,
// Declared before the buffers, so it is dropped before them (fields drop in
// declaration order).
binding: IoBinding,
// One buffer per model input, bound in input order.
inputs: Vec<TensorBuffer<T>>,
// Values fetched from the binding after the most recent run; empty until the first run.
outputs: Vec<OwnedValue>,
}
/// A reusable run lane whose inputs and outputs are both `AllocatedTensor`s.
pub struct AllocatedTensorIoLane<'s, T: TensorElement> {
// Session the lane runs on.
session: &'s Session,
// Declared before the tensors, so it is dropped before them (fields drop in
// declaration order).
binding: IoBinding,
// One tensor per model input, created with `AllocatedTensor::for_session`.
inputs: Vec<AllocatedTensor<T>>,
// One tensor per model output, created with `AllocatedTensor::for_session`.
outputs: Vec<AllocatedTensor<T>>,
}
/// A reusable run lane whose input and output counts are fixed at compile time.
pub struct StaticTensorIoLane<'s, T: TensorElement, const INPUTS: usize, const OUTPUTS: usize> {
// Session the lane runs on.
session: &'s Session,
// Declared before the buffers, so it is dropped before them (fields drop in
// declaration order).
binding: IoBinding,
// Exactly `INPUTS` buffers, bound in input order.
inputs: [TensorBuffer<T>; INPUTS],
// Exactly `OUTPUTS` buffers, bound in output order.
outputs: [TensorBuffer<T>; OUTPUTS],
}
/// Allocator statistics sampled immediately before and immediately after one lane run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LaneRunAllocatorStats {
/// Statistics read before the run started.
pub before: AllocatorStats,
/// Statistics read after the run returned.
pub after: AllocatorStats,
}
impl LaneRunAllocatorStats {
    /// Returns the result of `before.diff(&after)` for the two samples.
    #[inline]
    pub fn delta(&self) -> AllocatorStatsDelta {
        let (start, end) = (&self.before, &self.after);
        start.diff(end)
    }
}
pub(crate) fn lane_tensor_buffer<T>(
shape: &[i64], mem: &MemoryInfo, policy: LaneBufferPolicy,
) -> Result<TensorBuffer<T>>
where
T: TensorElement + Clone + Default,
{
let bytes = lane_shape_bytes::<T>(shape)?;
match resolve_lane_buffer_policy(policy, bytes) {
LaneBufferPolicy::Vec => TensorBuffer::zeros(shape, mem),
LaneBufferPolicy::Prefaulted => TensorBuffer::zeros_prefaulted(shape, mem),
LaneBufferPolicy::Aligned { alignment } => {
TensorBuffer::zeros_aligned(shape, alignment, mem)
},
LaneBufferPolicy::AlignedPrefaulted { alignment } => {
TensorBuffer::zeros_aligned_prefaulted(shape, alignment, mem)
},
LaneBufferPolicy::HugePage => {
TensorBuffer::zeros_aligned_hugepage(shape, HUGEPAGE_BUFFER_ALIGNMENT, mem)
},
LaneBufferPolicy::HugePagePrefaulted => {
TensorBuffer::zeros_aligned_hugepage_prefaulted(shape, HUGEPAGE_BUFFER_ALIGNMENT, mem)
},
LaneBufferPolicy::AlignedHugePagePrefaulted { alignment } => {
TensorBuffer::zeros_aligned_hugepage_prefaulted(shape, alignment, mem)
},
LaneBufferPolicy::AlignedMlocked { alignment } => {
TensorBuffer::zeros_aligned_mlocked(shape, alignment, mem)
},
LaneBufferPolicy::AlignedMlockedPrefaulted { alignment } => {
TensorBuffer::zeros_aligned_mlocked_prefaulted(shape, alignment, mem)
},
LaneBufferPolicy::HugePageMlocked => {
TensorBuffer::zeros_aligned_hugepage_mlocked(shape, HUGEPAGE_BUFFER_ALIGNMENT, mem)
},
LaneBufferPolicy::HugePageMlockedPrefaulted => {
TensorBuffer::zeros_aligned_hugepage_mlocked_prefaulted(
shape,
HUGEPAGE_BUFFER_ALIGNMENT,
mem,
)
},
LaneBufferPolicy::AlignedHugePageMlockedPrefaulted { alignment } => {
TensorBuffer::zeros_aligned_hugepage_mlocked_prefaulted(shape, alignment, mem)
},
LaneBufferPolicy::Auto => unreachable!("auto lane buffer policy must resolve first"),
}
}
/// Turns `LaneBufferPolicy::Auto` into a concrete policy based on `bytes`.
///
/// Any policy other than `Auto` is returned unchanged. For `Auto`, the larger threshold
/// is tested first: huge-page prefaulted, then aligned prefaulted, then a plain buffer.
fn resolve_lane_buffer_policy(policy: LaneBufferPolicy, bytes: usize) -> LaneBufferPolicy {
    if !matches!(policy, LaneBufferPolicy::Auto) {
        return policy;
    }
    if bytes >= AUTO_HUGEPAGE_BUFFER_THRESHOLD_BYTES {
        LaneBufferPolicy::HugePagePrefaulted
    } else if bytes >= AUTO_ALIGNED_BUFFER_THRESHOLD_BYTES {
        LaneBufferPolicy::AlignedPrefaulted {
            alignment: AUTO_ALIGNED_BUFFER_ALIGNMENT,
        }
    } else {
        LaneBufferPolicy::Vec
    }
}
fn lane_shape_bytes<T: TensorElement>(shape: &[i64]) -> Result<usize> {
let mut count = 1usize;
for &dim in shape {
if dim < 0 {
return Err(Error::new(
-1,
format!("zrt: lane buffers require concrete shapes, got {shape:?}"),
));
}
count = count
.checked_mul(dim as usize)
.ok_or_else(|| Error::new(-1, "zrt: lane buffer element count overflows usize"))?;
}
count
.checked_mul(std::mem::size_of::<T>())
.ok_or_else(|| Error::new(-1, "zrt: lane buffer byte size overflows usize"))
}
impl Session {
pub fn new(env: &Environment, model_path: &str, opts: SessionOptions) -> Result<Self> {
let cpath = CString::new(model_path)
.map_err(|_| crate::Error::new(-1, "model path contains a NUL"))?;
let opts_handle = build_session_options_for_env(env, &opts)?;
let mut sess: *mut sys::SessionHandle = ptr::null_mut();
let create = check(unsafe {
api().create_session()(
env.as_ptr(),
cpath.as_ptr(),
opts_handle as *const sys::SessionOptionsHandle,
&mut sess,
)
});
unsafe { api().release_session_options()(opts_handle) };
create?;
Self::from_handle(sess, env.share())
}
/// Creates a session from a model held in memory.
///
/// # Errors
/// Fails when the session options cannot be built for `env`, or when the runtime
/// reports an error while creating the session or while its I/O metadata is collected.
pub fn from_bytes(env: &Environment, model_data: &[u8], opts: SessionOptions) -> Result<Self> {
    let options = build_session_options_for_env(env, &opts)?;
    let mut handle: *mut sys::SessionHandle = ptr::null_mut();
    let outcome = check(unsafe {
        api().create_session_from_array()(
            env.as_ptr(),
            model_data.as_ptr().cast::<c_void>(),
            model_data.len(),
            options as *const sys::SessionOptionsHandle,
            &mut handle,
        )
    });
    // The options handle is released whether or not creation succeeded.
    unsafe { api().release_session_options()(options) };
    outcome?;
    Self::from_handle(handle, env.share())
}
pub fn new_with_prepacked_weights(
env: &Environment, model_path: &str, opts: SessionOptions,
prepacked: &PrepackedWeightsContainer,
) -> Result<Self> {
Self::new_with_prepacked_weights_and_owned_initializers(
env,
model_path,
opts,
prepacked,
Vec::new(),
)
}
pub fn new_with_owned_initializers(
env: &Environment, model_path: &str, opts: SessionOptions,
initializers: Vec<OwnedInitializer>,
) -> Result<Self> {
let cpath = CString::new(model_path)
.map_err(|_| crate::Error::new(-1, "model path contains a NUL"))?;
let opts_handle = build_session_options_for_env(env, &opts)?;
let create = (|| -> Result<*mut sys::SessionHandle> {
add_owned_initializers(opts_handle, &initializers)?;
let mut sess: *mut sys::SessionHandle = ptr::null_mut();
check(unsafe {
api().create_session()(
env.as_ptr(),
cpath.as_ptr(),
opts_handle as *const sys::SessionOptionsHandle,
&mut sess,
)
})?;
Ok(sess)
})();
unsafe { api().release_session_options()(opts_handle) };
let sess = create?;
Self::from_handle_with_resources(sess, env.share(), initializers, None)
}
pub fn new_with_prepacked_weights_and_owned_initializers(
env: &Environment, model_path: &str, opts: SessionOptions,
prepacked: &PrepackedWeightsContainer, initializers: Vec<OwnedInitializer>,
) -> Result<Self> {
let cpath = CString::new(model_path)
.map_err(|_| crate::Error::new(-1, "model path contains a NUL"))?;
let opts_handle = build_session_options_for_env(env, &opts)?;
let create = (|| -> Result<*mut sys::SessionHandle> {
add_owned_initializers(opts_handle, &initializers)?;
let mut sess: *mut sys::SessionHandle = ptr::null_mut();
check(unsafe {
api().create_session_with_prepacked_weights_container()(
env.as_ptr(),
cpath.as_ptr(),
opts_handle as *const sys::SessionOptionsHandle,
prepacked.as_mut_ptr(),
&mut sess,
)
})?;
Ok(sess)
})();
unsafe { api().release_session_options()(opts_handle) };
let sess = create?;
Self::from_handle_with_resources(sess, env.share(), initializers, Some(prepacked.share()))
}
pub fn from_bytes_with_prepacked_weights(
env: &Environment, model_data: &[u8], opts: SessionOptions,
prepacked: &PrepackedWeightsContainer,
) -> Result<Self> {
Self::from_bytes_with_prepacked_weights_and_owned_initializers(
env,
model_data,
opts,
prepacked,
Vec::new(),
)
}
/// Creates a session from an in-memory model after registering `initializers` on the
/// session options; the initializers are then stored in the returned session.
///
/// # Errors
/// Fails when the options cannot be built, when registering an initializer fails, or
/// when the runtime reports an error while creating the session or while its I/O
/// metadata is collected.
pub fn from_bytes_with_owned_initializers(
    env: &Environment, model_data: &[u8], opts: SessionOptions,
    initializers: Vec<OwnedInitializer>,
) -> Result<Self> {
    let options = build_session_options_for_env(env, &opts)?;
    // Both fallible steps live in one closure so that the options handle is released
    // on every path before the error is propagated.
    let create_session = || -> Result<*mut sys::SessionHandle> {
        add_owned_initializers(options, &initializers)?;
        let mut handle: *mut sys::SessionHandle = ptr::null_mut();
        let status = unsafe {
            api().create_session_from_array()(
                env.as_ptr(),
                model_data.as_ptr() as *const c_void,
                model_data.len(),
                options as *const sys::SessionOptionsHandle,
                &mut handle,
            )
        };
        check(status)?;
        Ok(handle)
    };
    let created = create_session();
    unsafe { api().release_session_options()(options) };
    let handle = created?;
    Self::from_handle_with_resources(handle, env.share(), initializers, None)
}
/// Creates a session from an in-memory model with both a shared prepacked-weights
/// container and caller-owned initializers.
///
/// The initializers and a shared handle to `prepacked` are stored in the returned
/// session.
///
/// # Errors
/// Fails when the options cannot be built, when registering an initializer fails, or
/// when the runtime reports an error while creating the session or while its I/O
/// metadata is collected.
pub fn from_bytes_with_prepacked_weights_and_owned_initializers(
    env: &Environment, model_data: &[u8], opts: SessionOptions,
    prepacked: &PrepackedWeightsContainer, initializers: Vec<OwnedInitializer>,
) -> Result<Self> {
    let options = build_session_options_for_env(env, &opts)?;
    // Both fallible steps live in one closure so that the options handle is released
    // on every path before the error is propagated.
    let create_session = || -> Result<*mut sys::SessionHandle> {
        add_owned_initializers(options, &initializers)?;
        let mut handle: *mut sys::SessionHandle = ptr::null_mut();
        let status = unsafe {
            api().create_session_from_array_with_prepacked_weights_container()(
                env.as_ptr(),
                model_data.as_ptr() as *const c_void,
                model_data.len(),
                options as *const sys::SessionOptionsHandle,
                prepacked.as_mut_ptr(),
                &mut handle,
            )
        };
        check(status)?;
        Ok(handle)
    };
    let created = create_session();
    unsafe { api().release_session_options()(options) };
    let handle = created?;
    Self::from_handle_with_resources(handle, env.share(), initializers, Some(prepacked.share()))
}
fn from_handle(sess: *mut sys::SessionHandle, env: Arc<EnvInner>) -> Result<Self> {
Self::from_handle_with_resources(sess, env, Vec::new(), None)
}
/// Wraps a raw session handle, collecting and caching its I/O names and metadata.
///
/// The supplied resources are stored in the returned `Session`.
///
/// # Errors
/// Fails when `sess` is null or when any metadata query or the creation of the default
/// run options fails. In the latter cases the session handle is released before the
/// error is returned.
fn from_handle_with_resources(
    sess: *mut sys::SessionHandle, env: Arc<EnvInner>,
    owned_initializers: Vec<OwnedInitializer>,
    prepacked_weights: Option<Arc<PrepackedWeightsInner>>,
) -> Result<Self> {
    let sess = crate::ensure_non_null(sess, "session")?;
    let build = || -> Result<Self> {
        let allocator = Allocator::get_default()?;
        let (input_names, input_ptrs) = collect_io_names(sess, true, &allocator)?;
        let (output_names, output_ptrs) = collect_io_names(sess, false, &allocator)?;
        let input_meta = collect_io_meta(sess, true, input_ptrs.len())?;
        let output_meta = collect_io_meta(sess, false, output_ptrs.len())?;
        Ok(Self {
            sess,
            input_names,
            input_ptrs,
            input_meta,
            output_names,
            output_ptrs,
            output_meta,
            run_opts: RunOptions::new()?,
            _owned_initializers: owned_initializers,
            _prepacked_weights: prepacked_weights,
            _env: env,
        })
    };
    match build() {
        Ok(session) => Ok(session),
        Err(err) => {
            // No `Session` value exists yet, so the handle has to be released by hand.
            unsafe { api().release_session()(sess) };
            Err(err)
        },
    }
}
/// Re-reads the session's input/output names and metadata and replaces the cached copies.
///
/// All queries run before any field is overwritten, so a failure leaves the cached
/// state untouched.
#[cfg(feature = "model-editor")]
fn refresh_io_metadata(&mut self) -> Result<()> {
    let allocator = Allocator::get_default()?;
    let (fresh_input_names, fresh_input_ptrs) = collect_io_names(self.sess, true, &allocator)?;
    let (fresh_output_names, fresh_output_ptrs) =
        collect_io_names(self.sess, false, &allocator)?;
    let fresh_input_meta = collect_io_meta(self.sess, true, fresh_input_ptrs.len())?;
    let fresh_output_meta = collect_io_meta(self.sess, false, fresh_output_ptrs.len())?;

    self.input_names = fresh_input_names;
    self.input_ptrs = fresh_input_ptrs;
    self.input_meta = fresh_input_meta;
    self.output_names = fresh_output_names;
    self.output_ptrs = fresh_output_ptrs;
    self.output_meta = fresh_output_meta;
    Ok(())
}
/// Fetches the model metadata for this session.
///
/// # Errors
/// Fails when the runtime reports an error or hands back a null metadata handle.
pub fn metadata(&self) -> Result<crate::metadata::ModelMetadata> {
    let session = self.sess as *const sys::SessionHandle;
    let mut raw: *mut sys::ModelMetadataHandle = ptr::null_mut();
    let status = unsafe { api().session_get_model_metadata()(session, &mut raw) };
    check(status)?;
    let owned = crate::ensure_non_null(raw, "model metadata")?;
    Ok(unsafe { crate::metadata::ModelMetadata::from_owning(owned) })
}
/// Returns the profiling start time reported by the runtime for this session.
///
/// # Errors
/// Fails when the runtime reports an error.
pub fn profiling_start_time_ns(&self) -> Result<u64> {
    let session = self.sess as *const sys::SessionHandle;
    let mut start_ns = 0u64;
    let status = unsafe { api().session_get_profiling_start_time_ns()(session, &mut start_ns) };
    check(status)?;
    Ok(start_ns)
}
/// Ends profiling and returns the profiling path string reported by the runtime.
///
/// # Errors
/// Fails when the default allocator is unavailable, when the runtime reports an error
/// or returns a null path, when the path cannot be converted to a `String`, or when
/// freeing the runtime-allocated string fails. A conversion error takes precedence
/// over a free error.
pub fn end_profiling(&self) -> Result<String> {
    let allocator = Allocator::get_default()?;
    let mut raw_path: *mut c_char = ptr::null_mut();
    let status =
        unsafe { api().session_end_profiling()(self.sess, allocator.alloc, &mut raw_path) };
    check(status)?;
    if raw_path.is_null() {
        return Err(Error::new(-1, "zrt: ORT returned null profiling path"));
    }
    // Copy the string first, then always free the runtime's buffer, and only then
    // look at either result.
    let copied = unsafe { crate::cstr_to_string(raw_path, "profiling path") };
    let freed = unsafe { allocator.free(raw_path as *mut c_void) };
    let path = copied?;
    freed?;
    Ok(path)
}
#[inline]
pub(crate) fn as_ptr(&self) -> *mut sys::SessionHandle {
self.sess
}
#[inline]
pub fn input_count(&self) -> usize {
self.input_ptrs.len()
}
#[inline]
pub fn output_count(&self) -> usize {
self.output_ptrs.len()
}
pub fn input_name(&self, i: usize) -> Result<&str> {
self.input_names
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: input index {i} out of range ({} inputs)",
self.input_count()
),
)
})?
.to_str()
.map_err(|_| Error::new(-1, format!("zrt: input name {i} is not valid UTF-8")))
}
pub fn output_name(&self, i: usize) -> Result<&str> {
self.output_names
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: output index {i} out of range ({} outputs)",
self.output_count()
),
)
})?
.to_str()
.map_err(|_| Error::new(-1, format!("zrt: output name {i} is not valid UTF-8")))
}
#[inline]
pub fn input_meta(&self, i: usize) -> Result<(sys::OnnxType, sys::ElementType, Option<usize>)> {
let m = self.input_meta.get(i).ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: input index {i} out of range ({} inputs)",
self.input_count()
),
)
})?;
Ok((m.onnx_type, m.elem_type, m.count))
}
#[inline]
pub fn output_meta(
&self, i: usize,
) -> Result<(sys::OnnxType, sys::ElementType, Option<usize>)> {
let m = self.output_meta.get(i).ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: output index {i} out of range ({} outputs)",
self.output_count()
),
)
})?;
Ok((m.onnx_type, m.elem_type, m.count))
}
#[inline]
pub fn input_shape(&self, i: usize) -> Result<&[i64]> {
Ok(&self
.input_meta
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: input index {i} out of range ({} inputs)",
self.input_count()
),
)
})?
.dims)
}
#[inline]
pub fn output_shape(&self, i: usize) -> Result<&[i64]> {
Ok(&self
.output_meta
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: output index {i} out of range ({} outputs)",
self.output_count()
),
)
})?
.dims)
}
#[inline]
pub fn input_symbolic_dims(&self, i: usize) -> Result<&[Option<String>]> {
Ok(&self
.input_meta
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: input index {i} out of range ({} inputs)",
self.input_count()
),
)
})?
.symbolic)
}
#[inline]
pub fn output_symbolic_dims(&self, i: usize) -> Result<&[Option<String>]> {
Ok(&self
.output_meta
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!(
"zrt: output index {i} out of range ({} outputs)",
self.output_count()
),
)
})?
.symbolic)
}
/// Runs the model with the session's own run options.
///
/// `inputs` must supply one value per model input and `outputs` one slot per model
/// output; each slot is overwritten with the produced value.
pub fn run(&self, inputs: &[&dyn RunInput], outputs: &mut [Option<OwnedValue>]) -> Result<()> {
    self.run_with(inputs, outputs, &self.run_opts)
}
pub fn prepare_run<'s, 'i>(
&'s self, inputs: &[&'i dyn RunInput],
) -> Result<PreparedRun<'s, 'i>> {
self.check_input_count(inputs.len())?;
Ok(PreparedRun {
session: self,
input_handles: inputs.iter().map(|v| v.as_value_ptr()).collect(),
output_handles: vec![ptr::null_mut(); self.output_count()],
outputs: (0..self.output_count()).map(|_| None).collect(),
_inputs: PhantomData,
})
}
/// Creates an `IoBinding` and binds the i-th input and output value to the i-th model
/// input and output name.
///
/// # Errors
/// Fails when the input or output count does not match the model, when a name cannot
/// be read, or when creating the binding or binding a value fails.
pub fn prepare_io_binding<'s, 'v>(
    &'s self, inputs: &[&'v dyn RunInput], outputs: &[&'v OutputValue<'_>],
) -> Result<PreparedIoBinding<'s, 'v>> {
    self.check_input_count(inputs.len())?;
    self.check_output_count(outputs.len(), "output count")?;

    let mut binding = IoBinding::new(self)?;
    for (index, input) in inputs.iter().enumerate() {
        let name = self.input_name(index)?;
        binding.bind_input(name, *input)?;
    }
    for (index, output) in outputs.iter().enumerate() {
        let name = self.output_name(index)?;
        binding.bind_output(name, output)?;
    }

    Ok(PreparedIoBinding {
        session: self,
        binding,
        _values: PhantomData,
    })
}
/// Creates an `IoBinding` with the given inputs and with `TensorBuffer`s as outputs,
/// matching the i-th value to the i-th model input and output name.
///
/// # Errors
/// Fails when the input or output count does not match the model, when a name cannot
/// be read, or when creating the binding or binding a value fails.
pub fn prepare_io_binding_buffers<'s, 'v, T: TensorElement>(
    &'s self, inputs: &[&'v dyn RunInput], outputs: &[&'v TensorBuffer<T>],
) -> Result<PreparedIoBinding<'s, 'v>> {
    self.check_input_count(inputs.len())?;
    self.check_output_count(outputs.len(), "output count")?;

    let mut binding = IoBinding::new(self)?;
    for (index, input) in inputs.iter().enumerate() {
        let name = self.input_name(index)?;
        binding.bind_input(name, *input)?;
    }
    for (index, buffer) in outputs.iter().enumerate() {
        let name = self.output_name(index)?;
        binding.bind_output_buffer(name, buffer)?;
    }

    Ok(PreparedIoBinding {
        session: self,
        binding,
        _values: PhantomData,
    })
}
/// Builds a `TensorIoLane` using `LaneBufferPolicy::Auto` for every buffer.
///
/// See `prepare_tensor_io_lane_with_buffer_policy` for the details and error cases.
pub fn prepare_tensor_io_lane<T>(
    &self, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
) -> Result<TensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    let auto = LaneBufferPolicy::Auto;
    self.prepare_tensor_io_lane_with_buffer_policy::<T>(mem, input_shapes, output_shapes, auto)
}
/// Builds a `TensorIoLane`: allocates one zeroed buffer per input and output shape
/// under `policy`, then binds them to the model's inputs and outputs in order.
///
/// # Errors
/// Fails when the shape counts do not match the model, when a buffer cannot be
/// allocated, or when creating the binding or binding a buffer fails.
pub fn prepare_tensor_io_lane_with_buffer_policy<T>(
    &self, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
    policy: LaneBufferPolicy,
) -> Result<TensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    self.check_input_count(input_shapes.len())?;
    self.check_output_count(output_shapes.len(), "output shape count")?;

    // All buffers are allocated before the binding is created.
    let mut inputs: Vec<TensorBuffer<T>> = Vec::with_capacity(input_shapes.len());
    for dims in input_shapes {
        inputs.push(lane_tensor_buffer(dims, mem, policy)?);
    }
    let mut outputs: Vec<TensorBuffer<T>> = Vec::with_capacity(output_shapes.len());
    for dims in output_shapes {
        outputs.push(lane_tensor_buffer(dims, mem, policy)?);
    }

    let mut binding = IoBinding::new(self)?;
    for (index, buffer) in inputs.iter().enumerate() {
        let name = self.input_name(index)?;
        binding.bind_input(name, buffer)?;
    }
    for (index, buffer) in outputs.iter().enumerate() {
        let name = self.output_name(index)?;
        binding.bind_output_buffer(name, buffer)?;
    }

    Ok(TensorIoLane {
        session: self,
        binding,
        inputs,
        outputs,
    })
}
/// Builds an `AllocatedOutputTensorIoLane` using `LaneBufferPolicy::Auto` for the
/// input buffers.
///
/// See `prepare_allocated_output_tensor_io_lane_with_buffer_policy` for the details
/// and error cases.
pub fn prepare_allocated_output_tensor_io_lane<T>(
    &self, input_mem: &MemoryInfo, output_mem: &MemoryInfo, input_shapes: &[&[i64]],
    output_shapes: &[&[i64]],
) -> Result<AllocatedOutputTensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    let auto = LaneBufferPolicy::Auto;
    self.prepare_allocated_output_tensor_io_lane_with_buffer_policy::<T>(
        input_mem,
        output_mem,
        input_shapes,
        output_shapes,
        auto,
    )
}
/// Builds an `AllocatedOutputTensorIoLane`: inputs are zeroed `TensorBuffer`s created
/// under `input_policy` with `input_mem`; outputs are `AllocatedTensor`s created for
/// this session with `output_mem`. Everything is bound in model order.
///
/// # Errors
/// Fails when the shape counts do not match the model, when a buffer or tensor cannot
/// be created, or when creating the binding or binding a value fails.
pub fn prepare_allocated_output_tensor_io_lane_with_buffer_policy<T>(
    &self, input_mem: &MemoryInfo, output_mem: &MemoryInfo, input_shapes: &[&[i64]],
    output_shapes: &[&[i64]], input_policy: LaneBufferPolicy,
) -> Result<AllocatedOutputTensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    self.check_input_count(input_shapes.len())?;
    self.check_output_count(output_shapes.len(), "output shape count")?;

    let mut inputs: Vec<TensorBuffer<T>> = Vec::with_capacity(input_shapes.len());
    for dims in input_shapes {
        inputs.push(lane_tensor_buffer(dims, input_mem, input_policy)?);
    }
    let mut outputs: Vec<AllocatedTensor<T>> = Vec::with_capacity(output_shapes.len());
    for dims in output_shapes {
        outputs.push(AllocatedTensor::for_session(self, output_mem, dims)?);
    }

    let mut binding = IoBinding::new(self)?;
    for (index, buffer) in inputs.iter().enumerate() {
        let name = self.input_name(index)?;
        binding.bind_input(name, buffer)?;
    }
    for (index, tensor) in outputs.iter().enumerate() {
        let name = self.output_name(index)?;
        binding.bind_output_allocated(name, tensor)?;
    }

    Ok(AllocatedOutputTensorIoLane {
        session: self,
        binding,
        inputs,
        outputs,
    })
}
/// Builds a `DeviceOutputTensorIoLane` using `LaneBufferPolicy::Auto` for the input
/// buffers.
///
/// See `prepare_device_output_tensor_io_lane_with_buffer_policy` for the details and
/// error cases.
pub fn prepare_device_output_tensor_io_lane<T>(
    &self, input_mem: &MemoryInfo, output_mem: &MemoryInfo, input_shapes: &[&[i64]],
) -> Result<DeviceOutputTensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    let auto = LaneBufferPolicy::Auto;
    self.prepare_device_output_tensor_io_lane_with_buffer_policy::<T>(
        input_mem,
        output_mem,
        input_shapes,
        auto,
    )
}
/// Builds a `DeviceOutputTensorIoLane`: inputs are zeroed `TensorBuffer`s created under
/// `input_policy`; every model output is bound to `output_mem` via
/// `bind_output_device`. The lane's output list starts empty.
///
/// # Errors
/// Fails when the input shape count does not match the model, when a buffer cannot be
/// allocated, or when creating the binding or binding a value fails.
pub fn prepare_device_output_tensor_io_lane_with_buffer_policy<T>(
    &self, input_mem: &MemoryInfo, output_mem: &MemoryInfo, input_shapes: &[&[i64]],
    input_policy: LaneBufferPolicy,
) -> Result<DeviceOutputTensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    self.check_input_count(input_shapes.len())?;

    let mut inputs: Vec<TensorBuffer<T>> = Vec::with_capacity(input_shapes.len());
    for dims in input_shapes {
        inputs.push(lane_tensor_buffer(dims, input_mem, input_policy)?);
    }

    let mut binding = IoBinding::new(self)?;
    for (index, buffer) in inputs.iter().enumerate() {
        let name = self.input_name(index)?;
        binding.bind_input(name, buffer)?;
    }
    let output_total = self.output_count();
    for index in 0..output_total {
        let name = self.output_name(index)?;
        binding.bind_output_device(name, output_mem)?;
    }

    Ok(DeviceOutputTensorIoLane {
        session: self,
        binding,
        inputs,
        outputs: Vec::new(),
    })
}
/// Builds an `AllocatedTensorIoLane`: inputs and outputs are `AllocatedTensor`s created
/// for this session with `input_mem` and `output_mem` respectively, bound in model order.
///
/// # Errors
/// Fails when the shape counts do not match the model, when a tensor cannot be
/// created, or when creating the binding or binding a tensor fails.
pub fn prepare_allocated_tensor_io_lane<T>(
    &self, input_mem: &MemoryInfo, output_mem: &MemoryInfo, input_shapes: &[&[i64]],
    output_shapes: &[&[i64]],
) -> Result<AllocatedTensorIoLane<'_, T>>
where
    T: TensorElement + Clone + Default,
{
    self.check_input_count(input_shapes.len())?;
    self.check_output_count(output_shapes.len(), "output shape count")?;

    let mut inputs: Vec<AllocatedTensor<T>> = Vec::with_capacity(input_shapes.len());
    for dims in input_shapes {
        inputs.push(AllocatedTensor::for_session(self, input_mem, dims)?);
    }
    let mut outputs: Vec<AllocatedTensor<T>> = Vec::with_capacity(output_shapes.len());
    for dims in output_shapes {
        outputs.push(AllocatedTensor::for_session(self, output_mem, dims)?);
    }

    let mut binding = IoBinding::new(self)?;
    for (index, tensor) in inputs.iter().enumerate() {
        let name = self.input_name(index)?;
        binding.bind_input(name, tensor)?;
    }
    for (index, tensor) in outputs.iter().enumerate() {
        let name = self.output_name(index)?;
        binding.bind_output_allocated(name, tensor)?;
    }

    Ok(AllocatedTensorIoLane {
        session: self,
        binding,
        inputs,
        outputs,
    })
}
/// Builds a `StaticTensorIoLane` using `LaneBufferPolicy::Auto` for every buffer.
///
/// See `prepare_static_tensor_io_lane_with_buffer_policy` for the details and error cases.
pub fn prepare_static_tensor_io_lane<T, const INPUTS: usize, const OUTPUTS: usize>(
    &self, mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
) -> Result<StaticTensorIoLane<'_, T, INPUTS, OUTPUTS>>
where
    T: TensorElement + Clone + Default,
{
    let auto = LaneBufferPolicy::Auto;
    self.prepare_static_tensor_io_lane_with_buffer_policy::<T, INPUTS, OUTPUTS>(
        mem,
        input_shapes,
        output_shapes,
        auto,
    )
}
pub fn prepare_static_tensor_io_lane_with_buffer_policy<
T,
const INPUTS: usize,
const OUTPUTS: usize,
>(
&self, mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
policy: LaneBufferPolicy,
) -> Result<StaticTensorIoLane<'_, T, INPUTS, OUTPUTS>>
where
T: TensorElement + Clone + Default,
{
self.check_input_count(INPUTS)?;
self.check_output_count(OUTPUTS, "output shape count")?;
let inputs: [TensorBuffer<T>; INPUTS] = input_shapes
.iter()
.map(|shape| lane_tensor_buffer(shape, mem, policy))
.collect::<Result<Vec<_>>>()?
.try_into()
.map_err(|_| Error::new(-1, "zrt: failed to build fixed input buffer array"))?;
let outputs: [TensorBuffer<T>; OUTPUTS] = output_shapes
.iter()
.map(|shape| lane_tensor_buffer(shape, mem, policy))
.collect::<Result<Vec<_>>>()?
.try_into()
.map_err(|_| Error::new(-1, "zrt: failed to build fixed output buffer array"))?;
let mut binding = IoBinding::new(self)?;
for (i, input) in inputs.iter().enumerate() {
binding.bind_input(self.input_name(i)?, input)?;
}
for (i, output) in outputs.iter().enumerate() {
binding.bind_output_buffer(self.output_name(i)?, output)?;
}
Ok(StaticTensorIoLane {
session: self,
binding,
inputs,
outputs,
})
}
/// Builds `lanes` independent `TensorIoLane`s using `LaneBufferPolicy::Auto`.
///
/// See `prepare_tensor_io_lanes_with_buffer_policy` for the details and error cases.
pub fn prepare_tensor_io_lanes<T>(
    &self, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]], lanes: usize,
) -> Result<Vec<TensorIoLane<'_, T>>>
where
    T: TensorElement + Clone + Default,
{
    let auto = LaneBufferPolicy::Auto;
    self.prepare_tensor_io_lanes_with_buffer_policy::<T>(
        mem,
        input_shapes,
        output_shapes,
        lanes,
        auto,
    )
}
/// Builds `lanes` independent `TensorIoLane`s with the same shapes and `policy`.
///
/// Lanes are built one after another; the first failure is returned and the lanes
/// built so far are dropped.
pub fn prepare_tensor_io_lanes_with_buffer_policy<T>(
    &self, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]], lanes: usize,
    policy: LaneBufferPolicy,
) -> Result<Vec<TensorIoLane<'_, T>>>
where
    T: TensorElement + Clone + Default,
{
    let mut prepared: Vec<TensorIoLane<'_, T>> = Vec::new();
    for _ in 0..lanes {
        let lane = self.prepare_tensor_io_lane_with_buffer_policy(
            mem,
            input_shapes,
            output_shapes,
            policy,
        )?;
        prepared.push(lane);
    }
    Ok(prepared)
}
/// Runs the model with caller-supplied run options.
///
/// `inputs` must supply one value per model input and `outputs` one slot per model
/// output; each slot is overwritten with the produced value.
pub fn run_with(
    &self, inputs: &[&dyn RunInput], outputs: &mut [Option<OwnedValue>], opts: &RunOptions,
) -> Result<()> {
    let raw_opts = opts.as_ptr();
    self.run_impl(inputs, outputs, raw_opts)
}
/// Shared implementation of `run` and `run_with`.
///
/// Validates the counts, gathers the raw input handles, executes the run and stores the
/// produced values in `outputs`. When both counts are at most `STACK_IO_HANDLES`, the
/// handle arrays are fixed-size locals; otherwise they are heap vectors.
fn run_impl(
    &self, inputs: &[&dyn RunInput], outputs: &mut [Option<OwnedValue>],
    opts: *const sys::RunOptionsHandle,
) -> Result<()> {
    self.check_input_count(inputs.len())?;
    self.check_output_count(outputs.len(), "output slot count")?;

    let (input_total, output_total) = (inputs.len(), outputs.len());
    let fits_inline = input_total <= STACK_IO_HANDLES && output_total <= STACK_IO_HANDLES;

    if !fits_inline {
        let in_handles: Vec<*const sys::ValueHandle> =
            inputs.iter().map(|value| value.as_value_ptr()).collect();
        let mut out_handles: Vec<*mut sys::ValueHandle> =
            vec![ptr::null_mut(); self.output_count()];
        self.run_raw(&in_handles, &mut out_handles, opts)?;
        return self.stamp_outputs(&out_handles, outputs);
    }

    let mut in_handles: [*const sys::ValueHandle; STACK_IO_HANDLES] =
        [ptr::null(); STACK_IO_HANDLES];
    for (slot, value) in in_handles.iter_mut().zip(inputs.iter()) {
        *slot = value.as_value_ptr();
    }
    let mut out_handles: [*mut sys::ValueHandle; STACK_IO_HANDLES] =
        [ptr::null_mut(); STACK_IO_HANDLES];
    // Only the prefixes that are actually in use are handed on.
    self.run_raw(
        &in_handles[..input_total],
        &mut out_handles[..output_total],
        opts,
    )?;
    self.stamp_outputs(&out_handles[..output_total], outputs)
}
/// Calls the runtime's `run` entry point with raw handle arrays.
///
/// The cached name-pointer arrays supply the input and output names. The output count
/// passed to the runtime is the model's output count, so `output_handles` is expected
/// to have room for that many entries.
fn run_raw(
    &self, input_handles: &[*const sys::ValueHandle],
    output_handles: &mut [*mut sys::ValueHandle], opts: *const sys::RunOptionsHandle,
) -> Result<()> {
    let input_names = self.input_ptrs.as_ptr();
    let output_names = self.output_ptrs.as_ptr();
    let output_total = self.output_ptrs.len();
    let status = unsafe {
        api().run()(
            self.sess,
            opts,
            input_names,
            input_handles.as_ptr(),
            input_handles.len(),
            output_names,
            output_total,
            output_handles.as_mut_ptr(),
        )
    };
    check(status)
}
fn check_input_count(&self, got: usize) -> Result<()> {
let expected = self.input_count();
if got != expected {
return Err(crate::Error::new(
-1,
format!("zrt: input count mismatch: expected {expected}, got {got}"),
));
}
Ok(())
}
fn check_output_count(&self, got: usize, what: &str) -> Result<()> {
let expected = self.output_count();
if got != expected {
return Err(crate::Error::new(
-1,
format!("zrt: {what} mismatch: expected {expected}, got {got}"),
));
}
Ok(())
}
/// Wraps each raw output handle in an `OwnedValue` and stores it in the matching slot
/// of `outputs`, attaching the cached type information for that output.
///
/// When no element count is cached for a tensor output, the count is read from the
/// produced value itself; non-tensor outputs without a cached count get `0`.
///
/// # Errors
/// Returns the error from the element-count query. In that case the handle being
/// processed and all later non-null handles are released; slots filled before the
/// failure keep their values.
///
/// # Panics
/// Indexing panics if `handles` is longer than `outputs` or than the cached output
/// metadata.
fn stamp_outputs(
&self, handles: &[*mut sys::ValueHandle], outputs: &mut [Option<OwnedValue>],
) -> Result<()> {
for i in 0..handles.len() {
let h = handles[i];
let m = &self.output_meta[i];
let count = match m.count {
Some(count) => count,
None if m.onnx_type == sys::OnnxType::Tensor => {
match crate::type_info::tensor_type_and_shape(h as *const sys::ValueHandle)
.and_then(|shape| shape.element_count())
{
Ok(count) => count,
Err(err) => {
// Handles from index `i` onward have not been wrapped in an `OwnedValue` yet,
// so they are released here before bailing out.
for &handle in &handles[i..] {
if !handle.is_null() {
unsafe { api().release_value()(handle) };
}
}
return Err(err);
},
}
},
None => 0,
};
// Assigning replaces whatever value the slot held before.
outputs[i] = Some(OwnedValue {
value: h,
onnx_type: m.onnx_type,
elem_type: m.elem_type,
count,
});
}
Ok(())
}
/// Runs the model through `binding` with the session's own run options.
///
/// Inputs are synchronized before the run and outputs after it.
pub fn run_binding(&self, binding: &crate::io_binding::IoBinding) -> Result<()> {
    self.run_binding_with(binding, &self.run_opts)
}
/// Runs the model through `binding` with caller-supplied run options.
///
/// Inputs are synchronized before the run and outputs after it; the first failing
/// step ends the call.
pub fn run_binding_with(
    &self, binding: &crate::io_binding::IoBinding, opts: &RunOptions,
) -> Result<()> {
    binding.synchronize_inputs()?;
    let status =
        unsafe { api().run_with_binding()(self.sess, opts.as_ptr(), binding.as_ptr()) };
    check(status)?;
    binding.synchronize_outputs()
}
/// Starts an asynchronous run with the session's own run options and returns a future
/// tied to the lifetime of the session and of `inputs`.
///
/// # Errors
/// Fails when the input count does not match the model or when the runtime refuses to
/// start the run.
pub fn run_async<'a>(&'a self, inputs: &'a [&'a dyn RunInput]) -> Result<RunFuture<'a>> {
self.check_input_count(inputs.len())?;
let n = self.output_count();
// The handle arrays are boxed and stored in the shared state below, so they stay
// allocated for as long as that state is alive.
let in_handles: Box<[*const sys::ValueHandle]> = inputs
.iter()
.map(|v| v.as_value_ptr())
.collect::<Vec<_>>()
.into_boxed_slice();
let mut out_handles: Box<[*mut sys::ValueHandle]> =
vec![ptr::null_mut(); n].into_boxed_slice();
// Raw pointers are taken before the boxes are moved into the state.
// NOTE(review): this relies on the heap allocations not moving when the boxes do —
// confirm this is acceptable under the aliasing rules the project checks against.
let in_ptr = in_handles.as_ptr();
let out_ptr = out_handles.as_mut_ptr();
let state = Arc::new(AsyncState {
result: UnsafeCell::new(None),
done: AtomicBool::new(false),
waker: AtomicWaker::new(),
_in_handles: in_handles,
_out_handles: out_handles,
});
// One extra strong reference is handed to the runtime as `user_data`; presumably
// `run_async_callback` (defined outside this view) takes it back.
let user_data = Arc::into_raw(state.clone()) as *mut c_void;
let started = check(unsafe {
api().run_async()(
self.sess,
self.run_opts.as_ptr(),
self.input_ptrs.as_ptr(),
in_ptr,
inputs.len(),
self.output_ptrs.as_ptr(),
self.output_ptrs.len(),
out_ptr,
Some(run_async_callback),
user_data,
)
});
if let Err(e) = started {
// The run did not start, so the reference handed out through `user_data` is
// reclaimed here. NOTE(review): assumes the runtime never invokes the callback after
// reporting a start failure — confirm.
unsafe {
drop(Arc::from_raw(user_data as *const AsyncState));
}
return Err(e);
}
Ok(RunFuture {
state,
_borrows: std::marker::PhantomData,
})
}
}
impl PreparedRun<'_, '_> {
pub fn run(&mut self) -> Result<&[Option<OwnedValue>]> {
for slot in &mut self.outputs {
*slot = None;
}
self.output_handles.fill(ptr::null_mut());
self.session.run_raw(
&self.input_handles,
&mut self.output_handles,
self.session.run_opts.as_ptr(),
)?;
let session = self.session;
session.stamp_outputs(&self.output_handles, &mut self.outputs)?;
Ok(&self.outputs)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
pub fn outputs(&self) -> &[Option<OwnedValue>] {
&self.outputs
}
pub fn output(&self, i: usize) -> Result<Option<&OwnedValue>> {
self.outputs
.get(i)
.map(Option::as_ref)
.ok_or_else(|| Error::new(-1, format!("zrt: output index {i} out of range")))
}
}
impl PreparedIoBinding<'_, '_> {
    /// Runs the session once through the prepared binding.
    pub fn run(&mut self) -> Result<()> {
        let session = self.session;
        session.run_binding(&self.binding)
    }

    /// Runs the binding `runs` times, stopping at the first error.
    pub fn prime(&mut self, runs: usize) -> Result<()> {
        (0..runs).try_for_each(|_| self.run())
    }

    /// Borrows the underlying binding.
    pub fn binding(&self) -> &IoBinding {
        let Self { binding, .. } = self;
        binding
    }
}
impl<T: TensorElement> TensorIoLane<'_, T> {
pub fn run(&mut self) -> Result<()> {
self.session.run_binding(&self.binding)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
pub fn run_with_allocator_stats(
&mut self, allocator: &Allocator,
) -> Result<LaneRunAllocatorStats> {
let before = allocator.stats()?;
self.run()?;
let after = allocator.stats()?;
Ok(LaneRunAllocatorStats { before, after })
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[T]> {
self.inputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.inputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn output(&self, i: usize) -> Result<&[T]> {
self.outputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn output_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.outputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn input_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.inputs
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn output_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.outputs
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
}
impl<T: TensorElement> AllocatedOutputTensorIoLane<'_, T> {
#[inline]
pub fn run(&mut self) -> Result<()> {
self.session.run_binding(&self.binding)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
pub fn run_with_allocator_stats(
&mut self, allocator: &Allocator,
) -> Result<LaneRunAllocatorStats> {
let before = allocator.stats()?;
self.run()?;
let after = allocator.stats()?;
Ok(LaneRunAllocatorStats { before, after })
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[T]> {
self.inputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated-output lane input index {i} out of range"),
)
})
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.inputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated-output lane input index {i} out of range"),
)
})
}
#[inline]
pub fn output(&self, i: usize) -> Result<&[T]> {
self.outputs
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated-output lane output index {i} out of range"),
)
})?
.as_slice()
}
#[inline]
pub fn output_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.outputs
.get_mut(i)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated-output lane output index {i} out of range"),
)
})?
.as_mut_slice()
}
#[inline]
pub fn input_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.inputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated-output lane input index {i} out of range"),
)
})
}
#[inline]
pub fn output_tensor(&self, i: usize) -> Result<&AllocatedTensor<T>> {
self.outputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated-output lane output index {i} out of range"),
)
})
}
}
impl<T: TensorElement> DeviceOutputTensorIoLane<'_, T> {
pub fn run(&mut self) -> Result<&[OwnedValue]> {
self.outputs.clear();
self.session.run_binding(&self.binding)?;
self.outputs = self.binding.output_values()?;
Ok(&self.outputs)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
pub fn run_with_allocator_stats(
&mut self, allocator: &Allocator,
) -> Result<LaneRunAllocatorStats> {
let before = allocator.stats()?;
self.run()?;
let after = allocator.stats()?;
Ok(LaneRunAllocatorStats { before, after })
}
#[inline]
pub fn outputs(&self) -> &[OwnedValue] {
&self.outputs
}
#[inline]
pub fn output(&self, i: usize) -> Result<&OwnedValue> {
self.outputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: device-output lane output index {i} out of range"),
)
})
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[T]> {
self.inputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: device-output lane input index {i} out of range"),
)
})
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.inputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: device-output lane input index {i} out of range"),
)
})
}
#[inline]
pub fn input_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.inputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: device-output lane input index {i} out of range"),
)
})
}
}
impl<T: TensorElement> AllocatedTensorIoLane<'_, T> {
#[inline]
pub fn run(&mut self) -> Result<()> {
self.session.run_binding(&self.binding)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
pub fn run_with_allocator_stats(
&mut self, allocator: &Allocator,
) -> Result<LaneRunAllocatorStats> {
let before = allocator.stats()?;
self.run()?;
let after = allocator.stats()?;
Ok(LaneRunAllocatorStats { before, after })
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[T]> {
self.inputs
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated tensor lane input index {i} out of range"),
)
})?
.as_slice()
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.inputs
.get_mut(i)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated tensor lane input index {i} out of range"),
)
})?
.as_mut_slice()
}
#[inline]
pub fn output(&self, i: usize) -> Result<&[T]> {
self.outputs
.get(i)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated tensor lane output index {i} out of range"),
)
})?
.as_slice()
}
#[inline]
pub fn output_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.outputs
.get_mut(i)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated tensor lane output index {i} out of range"),
)
})?
.as_mut_slice()
}
#[inline]
pub fn input_tensor(&self, i: usize) -> Result<&AllocatedTensor<T>> {
self.inputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated tensor lane input index {i} out of range"),
)
})
}
#[inline]
pub fn output_tensor(&self, i: usize) -> Result<&AllocatedTensor<T>> {
self.outputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: allocated tensor lane output index {i} out of range"),
)
})
}
}
impl<T: TensorElement, const INPUTS: usize, const OUTPUTS: usize>
StaticTensorIoLane<'_, T, INPUTS, OUTPUTS>
{
#[inline]
pub fn run(&mut self) -> Result<()> {
self.session.run_binding(&self.binding)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
pub fn run_with_allocator_stats(
&mut self, allocator: &Allocator,
) -> Result<LaneRunAllocatorStats> {
let before = allocator.stats()?;
self.run()?;
let after = allocator.stats()?;
Ok(LaneRunAllocatorStats { before, after })
}
#[inline]
pub fn inputs(&self) -> &[TensorBuffer<T>; INPUTS] {
&self.inputs
}
#[inline]
pub fn inputs_mut(&mut self) -> &mut [TensorBuffer<T>; INPUTS] {
&mut self.inputs
}
#[inline]
pub fn outputs(&self) -> &[TensorBuffer<T>; OUTPUTS] {
&self.outputs
}
#[inline]
pub fn outputs_mut(&mut self) -> &mut [TensorBuffer<T>; OUTPUTS] {
&mut self.outputs
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[T]> {
self.inputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.inputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn output(&self, i: usize) -> Result<&[T]> {
self.outputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn output_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.outputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn input_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.inputs
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn output_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.outputs
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
}
impl Drop for Session {
    /// Releases the underlying session handle, unless the handle is null.
    fn drop(&mut self) {
        if self.sess.is_null() {
            return;
        }
        // SAFETY: `sess` is non-null here; it is handed to the API's release
        // function exactly once, from this destructor.
        unsafe {
            api().release_session()(self.sess);
        }
    }
}
// `Session` stores raw pointers (the session handle and the cached name
// pointers), so the compiler does not derive `Send`/`Sync` automatically;
// they are asserted manually here.
// NOTE(review): soundness presumably relies on the underlying runtime allowing
// a session handle to be used from multiple threads — confirm against the
// runtime's threading guarantees.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
#[cfg(feature = "model-editor")]
impl Session {
    /// Creates a session from an in-memory model-editor [`crate::model_editor::Model`]
    /// via the `CreateSessionFromModel` entry point.
    ///
    /// The temporary session-options handle is released before any creation
    /// error is propagated.
    ///
    /// # Errors
    /// Fails when the model-editor API or the entry point is unavailable, when
    /// the options cannot be built, or when session creation fails.
    pub fn from_model(
        env: &Environment, model: &crate::model_editor::Model, opts: SessionOptions,
    ) -> Result<Self> {
        let Some(editor) = crate::model_editor::model_editor_api() else {
            return Err(crate::Error::new(-1, "ModelEditorApi unavailable"));
        };
        let create_fn = crate::model_editor::require_sub_api_fn(
            editor.CreateSessionFromModel,
            "ModelEditorApi",
            "CreateSessionFromModel",
        )?;
        let options = build_session_options_for_env(env, &opts)?;
        let mut handle: *mut sys::SessionHandle = ptr::null_mut();
        let status = check(unsafe {
            create_fn(
                env.as_ptr(),
                model.as_ptr(),
                options as *const sys::SessionOptionsHandle,
                &mut handle,
            )
        });
        // Release the options first so they are freed on both success and failure.
        unsafe { api().release_session_options()(options) };
        status?;
        Self::from_handle(handle, env.share())
    }

    /// Queries the opset version the session reports for `domain`
    /// via `SessionGetOpsetForDomain`.
    ///
    /// # Errors
    /// Fails when the model-editor API or the entry point is unavailable, when
    /// `domain` cannot be converted to a C string, or when the query fails.
    pub fn opset_for_domain(&self, domain: &str) -> Result<i32> {
        let Some(editor) = crate::model_editor::model_editor_api() else {
            return Err(crate::Error::new(-1, "ModelEditorApi unavailable"));
        };
        let query_fn = crate::model_editor::require_sub_api_fn(
            editor.SessionGetOpsetForDomain,
            "ModelEditorApi",
            "SessionGetOpsetForDomain",
        )?;
        let domain_c = CString::new(domain)?;
        let mut version: i32 = 0;
        let status = check(unsafe {
            query_fn(
                self.sess as *const sys::SessionHandle,
                domain_c.as_ptr(),
                &mut version,
            )
        });
        status?;
        Ok(version)
    }

    /// Creates a model-editor session from serialized model bytes
    /// via `CreateModelEditorSessionFromArray`.
    ///
    /// The temporary session-options handle is released before any creation
    /// error is propagated.
    ///
    /// # Errors
    /// Fails when the model-editor API or the entry point is unavailable, when
    /// the options cannot be built, or when session creation fails.
    pub fn from_bytes_for_editing(
        env: &Environment, model_data: &[u8], opts: SessionOptions,
    ) -> Result<Self> {
        let Some(editor) = crate::model_editor::model_editor_api() else {
            return Err(crate::Error::new(-1, "ModelEditorApi unavailable"));
        };
        let create_fn = crate::model_editor::require_sub_api_fn(
            editor.CreateModelEditorSessionFromArray,
            "ModelEditorApi",
            "CreateModelEditorSessionFromArray",
        )?;
        let options = build_session_options_for_env(env, &opts)?;
        let mut handle: *mut sys::SessionHandle = ptr::null_mut();
        let status = check(unsafe {
            create_fn(
                env.as_ptr(),
                model_data.as_ptr() as *const c_void,
                model_data.len(),
                options as *const sys::SessionOptionsHandle,
                &mut handle,
            )
        });
        // Release the options first so they are freed on both success and failure.
        unsafe { api().release_session_options()(options) };
        status?;
        Self::from_handle(handle, env.share())
    }

    /// Applies `model` to this session via `ApplyModelToModelEditorSession`.
    ///
    /// # Errors
    /// Fails when the model-editor API or the entry point is unavailable, or
    /// when the apply call reports an error.
    pub fn apply_model(&self, model: &crate::model_editor::Model) -> Result<()> {
        let Some(editor) = crate::model_editor::model_editor_api() else {
            return Err(crate::Error::new(-1, "ModelEditorApi unavailable"));
        };
        let apply_fn = crate::model_editor::require_sub_api_fn(
            editor.ApplyModelToModelEditorSession,
            "ModelEditorApi",
            "ApplyModelToModelEditorSession",
        )?;
        let model_ptr = model.as_ptr() as *mut sys::ModelHandle;
        check(unsafe { apply_fn(self.sess, model_ptr) })
    }

    /// Finalizes this model-editor session via `FinalizeModelEditorSession`
    /// and then refreshes the cached I/O metadata.
    ///
    /// A null pointer is passed as the third argument of the finalize call.
    /// The temporary options handle is released before any finalize error is
    /// propagated.
    ///
    /// # Errors
    /// Fails when the model-editor API or the entry point is unavailable, when
    /// the options cannot be built, when finalization fails, or when the
    /// metadata refresh fails.
    pub fn finalize(&mut self, opts: &SessionOptions) -> Result<()> {
        let Some(editor) = crate::model_editor::model_editor_api() else {
            return Err(crate::Error::new(-1, "ModelEditorApi unavailable"));
        };
        let finalize_fn = crate::model_editor::require_sub_api_fn(
            editor.FinalizeModelEditorSession,
            "ModelEditorApi",
            "FinalizeModelEditorSession",
        )?;
        let options = opts.build_handle()?;
        let status = check(unsafe {
            finalize_fn(
                self.sess,
                options as *const sys::SessionOptionsHandle,
                ptr::null_mut(),
            )
        });
        unsafe { api().release_session_options()(options) };
        status?;
        self.refresh_io_metadata()
    }
}
/// State shared between a [`RunFuture`] and the asynchronous-run completion
/// callback.
struct AsyncState {
// Slot for the run's outcome; written by `complete`, emptied by `take_result`.
result: UnsafeCell<Option<Result<Vec<OwnedValue>>>>,
// Set (with `Release` ordering) once `result` has been written.
done: AtomicBool,
// Waker of the task polling the future; woken after `done` is set.
waker: AtomicWaker,
// NOTE(review): presumably kept here so the handle arrays handed to the
// asynchronous run stay allocated until this state is dropped — confirm
// against the code that starts the run.
_in_handles: Box<[*const sys::ValueHandle]>,
_out_handles: Box<[*mut sys::ValueHandle]>,
}
// The `UnsafeCell` and the raw-pointer arrays prevent the auto traits from
// being derived, so they are asserted manually. Within this file `result` is
// written before `done` is stored with `Release`, and is only read after an
// `Acquire` load of `done` observes `true`.
unsafe impl Send for AsyncState {}
unsafe impl Sync for AsyncState {}
/// Future that resolves to the output values of an asynchronous run (or to
/// the error it produced).
pub struct RunFuture<'a> {
// Completion state shared with the callback side.
state: Arc<AsyncState>,
// Lifetime marker only; carries no data.
// NOTE(review): presumably ties the future to borrows made when the run
// was started — confirm against the constructor.
_borrows: std::marker::PhantomData<&'a ()>,
}
impl<'a> std::future::Future for RunFuture<'a> {
    type Output = Result<Vec<OwnedValue>>;

    /// Resolves once the shared state has been marked done.
    ///
    /// The `done` flag is checked, then the waker is registered, then the flag
    /// is checked again so that a completion landing between the first check
    /// and the registration is not missed.
    fn poll(
        self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        use std::task::Poll;

        let shared = &self.state;
        if !shared.done.load(Ordering::Acquire) {
            shared.waker.register(cx.waker());
            if !shared.done.load(Ordering::Acquire) {
                return Poll::Pending;
            }
        }
        Poll::Ready(shared.take_result())
    }
}
impl AsyncState {
fn complete(&self, result: Result<Vec<OwnedValue>>) {
unsafe { *self.result.get() = Some(result) };
self.done.store(true, Ordering::Release);
self.waker.wake();
}
fn take_result(&self) -> Result<Vec<OwnedValue>> {
unsafe { (*self.result.get()).take() }
.unwrap_or_else(|| Err(crate::Error::new(-1, "zrt: async result already consumed")))
}
}
/// C-ABI completion callback for an asynchronous run.
///
/// Rebuilds the `Arc<AsyncState>` from `user_data`, turns the callback
/// arguments into a `Result<Vec<OwnedValue>>`, and hands it to
/// [`AsyncState::complete`]. The rebuilt `Arc` is dropped when this function
/// returns.
///
/// Outcome mapping:
/// - non-null `status`: an error (taken from `check(status)`, or a generic
///   failure if `check` unexpectedly reports `Ok`);
/// - null `outputs`: an empty vector;
/// - otherwise: `num_outputs` handles collected via `OwnedValue::collect_from_raw`;
/// - a panic while doing any of the above: a "panic in RunAsync callback" error.
///
/// # Safety
/// `user_data` must be a pointer obtained from `Arc::<AsyncState>::into_raw`
/// whose reference count this call is allowed to consume, and `outputs`, when
/// non-null, must point to `num_outputs` readable handles.
/// NOTE(review): the matching `Arc::into_raw` is presumably in the code that
/// starts the asynchronous run — not visible in this part of the file.
#[allow(clippy::from_raw_with_void_ptr)] unsafe extern "C" fn run_async_callback(
user_data: *mut c_void, outputs: *mut *mut sys::ValueHandle, num_outputs: usize,
status: sys::StatusPtr,
) {
unsafe {
// Take back ownership of the reference that was leaked into `user_data`.
let state: Arc<AsyncState> = Arc::from_raw(user_data as *const AsyncState);
// Catch panics so they are reported as an error result instead of
// unwinding out of this `extern "C"` function.
let result: Result<Vec<OwnedValue>> =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
if !status.is_null() {
return Err(match check(status) {
Err(e) => e,
Ok(()) => crate::Error::new(
sys::OrtErrorCode::Fail as i32,
"RunAsync returned a non-null but Ok status",
),
});
}
if outputs.is_null() {
return Ok(Vec::new());
}
let handles = std::slice::from_raw_parts(outputs, num_outputs);
OwnedValue::collect_from_raw(handles)
}))
.unwrap_or_else(|_| {
Err(crate::Error::new(
sys::OrtErrorCode::Fail as i32,
"panic in RunAsync callback",
))
});
// Publish the outcome and wake the task waiting on the future.
state.complete(result);
}
}
/// Registers every initializer in `initializers` on the session-options
/// handle `opts`, in order, stopping at the first failure.
fn add_owned_initializers(
    opts: *mut sys::SessionOptionsHandle, initializers: &[OwnedInitializer],
) -> Result<()> {
    initializers.iter().try_for_each(|initializer| {
        check(unsafe {
            api().add_initializer()(opts, initializer.name_ptr(), initializer.value_ptr())
        })
    })
}
/// Builds a session-options handle from `opts` and adapts it to `env`.
///
/// Steps:
/// 1. build the raw handle from `opts`;
/// 2. apply the EP device attachment (a no-op unless the `ep` feature is on);
/// 3. disable per-session threads when the environment has a global thread
///    pool and `opts.use_global_thread_pool` is set.
///
/// On success the caller owns the returned handle and must release it.
///
/// # Errors
/// Returns the first error from any step. The handle is released exactly once
/// on every failure path, so the caller must not release anything on `Err`.
fn build_session_options_for_env(
    env: &Environment, opts: &SessionOptions,
) -> Result<*mut sys::SessionOptionsHandle> {
    let opts_handle = opts.build_handle()?;
    // `apply_ep_device_attach_or_release` releases `opts_handle` itself before
    // returning an error, so it must NOT be released again here; doing so would
    // release the same handle twice.
    apply_ep_device_attach_or_release(env, opts_handle, opts)?;
    if env.has_global_thread_pool() && opts.use_global_thread_pool {
        if let Err(err) = check(unsafe { api().disable_per_session_threads()(opts_handle) }) {
            // This failure path still owns the handle, so release it here.
            unsafe { api().release_session_options()(opts_handle) };
            return Err(err);
        }
    }
    Ok(opts_handle)
}
/// Applies the EP device attachment from `opts` to `opts_handle`.
///
/// With the `ep` feature enabled, a failed attachment releases `opts_handle`
/// before the error is returned, so the caller must not release it again on
/// `Err`. Without the feature this function does nothing and returns `Ok(())`.
fn apply_ep_device_attach_or_release(
    env: &Environment, opts_handle: *mut sys::SessionOptionsHandle, opts: &SessionOptions,
) -> Result<()> {
    #[cfg(feature = "ep")]
    match crate::ep_device::apply_device_attach(env, opts_handle, &opts.ep_device_attach) {
        Ok(_) => {}
        Err(err) => {
            unsafe { api().release_session_options()(opts_handle) };
            return Err(err);
        }
    }
    // Marks the parameters as used when the `ep` feature is compiled out.
    let _ = (env, opts_handle, opts);
    Ok(())
}
/// Collects the names of all session inputs (`is_input == true`) or outputs.
///
/// Returns the owned names together with a parallel vector of raw C-string
/// pointers into those names. Each name string obtained from the API is copied
/// into an owned `CString` and then freed through `alloc`.
///
/// # Errors
/// Fails when a count or name query fails, when the API yields a null name
/// pointer, or when freeing a name through `alloc` fails.
fn collect_io_names(
    sess: *mut sys::SessionHandle, is_input: bool, alloc: &Allocator,
) -> Result<(Vec<CString>, Vec<*const c_char>)> {
    let ort = api();
    let session = sess as *const sys::SessionHandle;

    let mut total: usize = 0;
    let count_status = unsafe {
        if is_input {
            ort.session_get_input_count()(session, &mut total)
        } else {
            ort.session_get_output_count()(session, &mut total)
        }
    };
    check(count_status)?;

    let mut owned_names = Vec::with_capacity(total);
    for index in 0..total {
        let mut raw_name: *mut c_char = ptr::null_mut();
        let name_status = unsafe {
            if is_input {
                ort.session_get_input_name()(session, index, alloc.alloc, &mut raw_name)
            } else {
                ort.session_get_output_name()(session, index, alloc.alloc, &mut raw_name)
            }
        };
        check(name_status)?;
        if raw_name.is_null() {
            return Err(Error::new(-1, "zrt: session I/O name pointer is null"));
        }
        // Copy the name before handing the original allocation back.
        let name = unsafe { CStr::from_ptr(raw_name) }.to_owned();
        unsafe { alloc.free(raw_name as *mut c_void) }?;
        owned_names.push(name);
    }

    let name_ptrs = owned_names.iter().map(|name| name.as_ptr()).collect();
    Ok((owned_names, name_ptrs))
}
/// Builds one [`CachedIo`] entry for each of the first `count` session inputs
/// (`is_input == true`) or outputs.
///
/// For tensor-typed entries the element type, dimensions, symbolic dimension
/// names and (when it can be computed) the total element count are cached.
/// Every other ONNX type gets an `Undefined` element type, a count of
/// `Some(0)` and empty dimension lists.
///
/// # Errors
/// Returns the first error reported by any API query or by the symbolic-name
/// conversion. The type-info handle of the failing entry is still released.
fn collect_io_meta(
sess: *mut sys::SessionHandle, is_input: bool, count: usize,
) -> Result<Vec<CachedIo>> {
let api = api();
let mut out = Vec::with_capacity(count);
for i in 0..count {
// Declared outside the closure so it can be released below no matter
// where the closure bails out.
let mut type_info: *mut sys::TypeInfoHandle = ptr::null_mut();
let meta = (|| -> Result<CachedIo> {
check(unsafe {
if is_input {
api.session_get_input_type_info()(
sess as *const sys::SessionHandle,
i,
&mut type_info,
)
} else {
api.session_get_output_type_info()(
sess as *const sys::SessionHandle,
i,
&mut type_info,
)
}
})?;
let mut onnx_type = sys::OnnxType::Unknown;
check(unsafe {
api.get_onnx_type_from_type_info()(
type_info as *const sys::TypeInfoHandle,
&mut onnx_type,
)
})?;
if onnx_type == sys::OnnxType::Tensor {
// `tensor_info` is not released separately in this function.
// NOTE(review): presumably it is a view owned by `type_info` — confirm.
let mut tensor_info: *const sys::TensorTypeAndShapeInfoHandle = ptr::null();
check(unsafe {
api.cast_type_info_to_tensor_info()(
type_info as *const sys::TypeInfoHandle,
&mut tensor_info,
)
})?;
let mut etype = sys::ElementType::Undefined;
check(unsafe { api.get_tensor_element_type()(tensor_info, &mut etype) })?;
let mut rank: usize = 0;
check(unsafe { api.get_dimensions_count()(tensor_info, &mut rank) })?;
let mut dims = vec![0i64; rank];
check(unsafe { api.get_dimensions()(tensor_info, dims.as_mut_ptr(), rank) })?;
let mut sptrs: Vec<*const c_char> = vec![ptr::null(); rank];
check(unsafe {
api.get_symbolic_dimensions()(tensor_info, sptrs.as_mut_ptr(), rank)
})?;
// A null entry becomes `None`; non-null entries are converted to `String`.
let symbolic = sptrs
.iter()
.map(|&p| {
if p.is_null() {
Ok(None)
} else {
unsafe { crate::cstr_to_string(p, "symbolic dimension") }.map(Some)
}
})
.collect::<Result<Vec<_>>>()?;
// `None` when the element count cannot be computed from `dims`.
let count = crate::type_info::checked_element_count(&dims).ok();
Ok(CachedIo {
onnx_type,
elem_type: etype,
count,
dims,
symbolic,
})
} else {
Ok(CachedIo {
onnx_type,
elem_type: sys::ElementType::Undefined,
count: Some(0),
dims: Vec::new(),
symbolic: Vec::new(),
})
}
})();
// Release the type info before propagating any error from the closure.
if !type_info.is_null() {
unsafe { api.release_type_info()(type_info) };
}
out.push(meta?);
}
Ok(out)
}
#[cfg(test)]
mod tests {
// No unit tests are defined in this module yet.
}