ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
use alloc::{
	borrow::Cow,
	sync::{Arc, Weak},
	vec::Vec
};
use core::{
	any::Any,
	ptr::{self, NonNull}
};

use smallvec::SmallVec;

use crate::{
	AsPointer, Error,
	environment::{self, Environment},
	error::Result,
	logging::LoggerFunction,
	memory::MemoryInfo,
	operator::OperatorDomain,
	ortsys,
	util::with_cstr,
	value::DynValue
};

#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
mod editable;
mod impl_commit;
mod impl_config_keys;
mod impl_options;

#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub use self::editable::*;
pub use self::impl_options::*;

/// `Result` type returned by [`SessionBuilder`] methods.
///
/// The error variant is an `Error<SessionBuilder>`: a failed configuration step hands the builder back inside the
/// error, so the caller can decide to ignore the failure and keep configuring.
///
/// This type supports [error recovery](Error::recover):
/// ```
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?
/// 	.with_optimization_level(GraphOptimizationLevel::All)
/// 	// Optimization isn't enabled in minimal builds of ONNX Runtime, so this may return an error. We can recover the
/// 	// builder from the error and carry on.
/// 	.unwrap_or_else(|e| e.recover())
/// 	.commit_from_file("tests/data/upsample.onnx")?;
/// # Ok(())
/// # }
/// ```
pub type BuilderResult = Result<SessionBuilder, Error<SessionBuilder>>;

/// Creates a session using the builder pattern.
///
/// Once configured, use the
/// [`SessionBuilder::commit_from_file`](crate::session::builder::SessionBuilder::commit_from_file) method to 'commit'
/// the builder configuration into a [`Session`].
///
/// Configuration methods such as [`SessionBuilder::with_config_entry`] return a [`BuilderResult`], whose error
/// variant carries the builder so a failed step can be recovered from.
///
/// ```
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?
/// 	.with_optimization_level(GraphOptimizationLevel::Level1)?
/// 	.with_intra_threads(1)?
/// 	.commit_from_file("tests/data/upsample.onnx")?;
/// # Ok(())
/// # }
/// ```
///
/// [`Session`]: crate::session::Session
pub struct SessionBuilder {
	// Owned `OrtSessionOptions` handle. It sits behind an `Arc` so that `canceler()` can hand out a `Weak` to it.
	// `Clone` creates a fresh handle via `CloneSessionOptions` instead of sharing this one.
	session_options_ptr: Arc<SessionOptionsPointer>,
	// NOTE(review): how this is consumed is not visible in this file; presumably used when committing. Confirm.
	memory_info: Option<Arc<MemoryInfo>>,
	// Operator domains attached to this builder. Storage is inline for one domain before spilling to the heap.
	operator_domains: SmallVec<[Arc<OperatorDomain>; 1]>,
	// Initializer values. They are `Arc`-shared, so builder clones reference the same values.
	initializers: Vec<Arc<DynValue>>,
	// Byte buffers for external initializers, either borrowed for `'static` or owned.
	external_initializer_buffers: Vec<Cow<'static, [u8]>>,
	prepacked_weights: Option<PrepackedWeights>,
	// Type-erased value. NOTE(review): presumably keeps a user thread manager alive for the session's lifetime.
	thread_manager: Option<Arc<dyn Any>>,
	// NOTE(review): presumably a per-session logging callback. Confirm against the options/commit modules.
	logger: Option<Arc<LoggerFunction>>,
	// NOTE(review): the two flags below are set elsewhere (not in this file). Their names suggest opting out of the
	// environment's global thread pool and environment-registered EPs respectively. Confirm.
	no_global_thread_pool: bool,
	no_env_eps: bool,
	// Environment captured by `SessionBuilder::new` via `environment::current()`. Fields drop in declaration order,
	// so this is released after `session_options_ptr`. NOTE(review): confirm this ordering is relied upon.
	pub(crate) environment: Arc<Environment>
}

impl Clone for SessionBuilder {
	/// Produces a builder with its own copy of the underlying `OrtSessionOptions`.
	///
	/// ONNX Runtime duplicates the options through `CloneSessionOptions`. Changes made to the copy therefore do not
	/// reach the original, and a load canceler taken from one builder does not affect the other. Members stored
	/// behind an `Arc` (operator domains, initializers, thread manager, logger, environment, ...) are shared by
	/// bumping their reference counts rather than deep-copied.
	///
	/// # Panics
	///
	/// Panics if ONNX Runtime fails to clone the session options.
	fn clone(&self) -> Self {
		let mut cloned_options = ptr::null_mut();
		ortsys![
			unsafe CloneSessionOptions(self.ptr(), ptr::addr_of_mut!(cloned_options))
				.expect("error cloning session options");
			nonNull(cloned_options)
		];
		let session_options_ptr = Arc::new(SessionOptionsPointer::new(cloned_options));

		Self {
			session_options_ptr,
			memory_info: self.memory_info.as_ref().map(Arc::clone),
			operator_domains: self.operator_domains.clone(),
			initializers: self.initializers.clone(),
			external_initializer_buffers: self.external_initializer_buffers.clone(),
			prepacked_weights: self.prepacked_weights.clone(),
			thread_manager: self.thread_manager.as_ref().map(Arc::clone),
			logger: self.logger.as_ref().map(Arc::clone),
			no_global_thread_pool: self.no_global_thread_pool,
			no_env_eps: self.no_env_eps,
			environment: Arc::clone(&self.environment)
		}
	}
}

impl SessionBuilder {
	/// Creates a new session builder bound to the current [`Environment`].
	///
	/// ```
	/// # use ort::session::{builder::GraphOptimizationLevel, Session};
	/// # fn main() -> ort::Result<()> {
	/// let session = Session::builder()?
	/// 	.with_optimization_level(GraphOptimizationLevel::Level1)?
	/// 	.with_intra_threads(1)?
	/// 	.commit_from_file("tests/data/upsample.onnx")?;
	/// # Ok(())
	/// # }
	/// ```
	///
	/// # Errors
	///
	/// Returns an error in either of these cases:
	/// - the current environment cannot be obtained;
	/// - ONNX Runtime fails to create the session options.
	pub fn new() -> Result<Self> {
		let environment = environment::current()?;

		let mut options_ptr: *mut ort_sys::OrtSessionOptions = ptr::null_mut();
		ortsys![unsafe CreateSessionOptions(&mut options_ptr)?; nonNull(options_ptr)];

		// Default to on-device usage by preferring efficiency. `.with_execution_providers` and `.with_auto_ep`
		// override this later. The call is best-effort, so its result is deliberately discarded.
		#[cfg(feature = "api-22")]
		let _ = ortsys![@ort: unsafe SessionOptionsSetEpSelectionPolicy(options_ptr.as_ptr(), AutoDevicePolicy::MaxEfficiency.into()) as Result];

		let session_options_ptr = Arc::new(SessionOptionsPointer::new(options_ptr));
		Ok(Self {
			session_options_ptr,
			memory_info: None,
			operator_domains: SmallVec::new(),
			initializers: Vec::new(),
			external_initializer_buffers: Vec::new(),
			prepacked_weights: None,
			thread_manager: None,
			logger: None,
			no_global_thread_pool: false,
			no_env_eps: false,
			environment
		})
	}

	/// Stores `key = value` in the session options' configuration through `AddSessionConfigEntry`.
	///
	/// Each string is passed through `with_cstr`, and the resulting C-string pointers are handed to ONNX Runtime.
	#[inline]
	pub(crate) fn add_config_entry(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
		let options = self.ptr_mut();
		let (key, value) = (key.as_ref(), value.as_ref());
		with_cstr(key.as_bytes(), &|c_key| {
			with_cstr(value.as_bytes(), &|c_value| {
				ortsys![unsafe AddSessionConfigEntry(options, c_key.as_ptr(), c_value.as_ptr())?];
				Ok(())
			})
		})
	}

	/// Returns a [`LoadCanceler`] that another thread can use to abort in-progress commits of this builder.
	///
	/// The canceler holds only a weak reference to this builder's session options. Once every strong reference is
	/// gone, [`LoadCanceler::cancel`] does nothing. A clone of this builder owns separate options and is not affected.
	///
	/// ```
	/// # use ort::session::{builder::GraphOptimizationLevel, Session};
	/// # use std::{thread, time::Duration};
	/// # fn main() -> ort::Result<()> {
	/// let mut builder = Session::builder()?
	/// 	.with_optimization_level(GraphOptimizationLevel::Level1)?
	/// 	.with_intra_threads(1)?;
	///
	/// let canceler = builder.canceler();
	/// thread::spawn(move || {
	/// 	thread::sleep(Duration::from_millis(500));
	/// 	// timeout if model hasn't loaded in 500ms
	/// 	let _ = canceler.cancel();
	/// });
	///
	/// let session = builder.commit_from_file("tests/data/upsample.onnx")?;
	/// # Ok(())
	/// # }
	/// ```
	#[cfg(feature = "api-22")]
	#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
	pub fn canceler(&self) -> LoadCanceler {
		let weak_options = Arc::downgrade(&self.session_options_ptr);
		LoadCanceler(weak_options)
	}

	/// Adds a custom configuration entry (`key` = `value`) to the session.
	///
	/// # Errors
	///
	/// If ONNX Runtime rejects the entry, the error is returned with this builder attached, so it can be
	/// [recovered](Error::recover).
	pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> BuilderResult {
		if let Err(e) = self.add_config_entry(key, value) {
			return Err(e.with_recover(self));
		}
		Ok(self)
	}
}

impl AsPointer for SessionBuilder {
	type Sys = ort_sys::OrtSessionOptions;

	fn ptr(&self) -> *const Self::Sys {
		self.session_options_ptr.as_ptr()
	}
}

/// A handle which can be used to remotely terminate an in-progress session load.
///
/// The handle holds only a weak reference to the builder's session options and does not keep them alive. Calling
/// [`LoadCanceler::cancel`] after the options have been released is a no-op.
///
/// See [`SessionBuilder::canceler`].
#[derive(Debug, Clone)]
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub struct LoadCanceler(Weak<SessionOptionsPointer>);

// SAFETY: `Weak`/`Arc` reference counting is atomic. The only operation a `LoadCanceler` performs on the pointee is
// `SessionOptionsSetLoadCancellationFlag` (in `cancel`). Its intended use, per the example on
// `SessionBuilder::canceler`, is being called from another thread while a commit is running.
// NOTE(review): if the temporary `Arc` upgraded in `cancel` turns out to be the last strong reference,
// `ReleaseSessionOptions` runs on the canceling thread. Confirm ONNX Runtime allows that.
#[cfg(feature = "api-22")]
unsafe impl Send for LoadCanceler {}
// SAFETY: `cancel` takes `&self` and only upgrades the `Weak` and sets the cancellation flag, so sharing references
// across threads relies on the same guarantees described for `Send` above.
#[cfg(feature = "api-22")]
unsafe impl Sync for LoadCanceler {}

#[cfg(feature = "api-22")]
impl LoadCanceler {
	/// Requests cancellation of any session commit that uses the associated builder's session options.
	///
	/// If those session options have already been released, nothing happens and `Ok(())` is returned.
	///
	/// ```
	/// # use ort::session::{builder::GraphOptimizationLevel, Session};
	/// # use std::{thread, time::Duration};
	/// # fn main() -> ort::Result<()> {
	/// let mut builder = Session::builder()?
	/// 	.with_optimization_level(GraphOptimizationLevel::Level1)?
	/// 	.with_intra_threads(1)?;
	///
	/// let canceler = builder.canceler();
	/// thread::spawn(move || {
	/// 	thread::sleep(Duration::from_millis(500));
	/// 	// timeout if model hasn't loaded in 500ms
	/// 	let _ = canceler.cancel();
	/// });
	///
	/// let session = builder.commit_from_file("tests/data/upsample.onnx")?;
	/// # Ok(())
	/// # }
	/// ```
	///
	/// # Errors
	///
	/// Returns an error if ONNX Runtime fails to set the load-cancellation flag.
	#[cfg(feature = "api-22")]
	#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
	pub fn cancel(&self) -> Result<()> {
		let options = match self.0.upgrade() {
			Some(options) => options,
			// The session options are gone; there is nothing left to cancel.
			None => return Ok(())
		};
		ortsys![unsafe SessionOptionsSetLoadCancellationFlag(options.as_ptr(), true)?];
		Ok(())
	}
}

/// Owning wrapper around a non-null `OrtSessionOptions` handle; the handle is released in this type's `Drop` impl.
#[derive(Debug)]
#[repr(transparent)]
pub(crate) struct SessionOptionsPointer(NonNull<ort_sys::OrtSessionOptions>);

impl SessionOptionsPointer {
	/// Takes ownership of `ptr` after reporting it to `crate::logging::create!`.
	#[inline]
	pub(crate) fn new(ptr: NonNull<ort_sys::OrtSessionOptions>) -> Self {
		crate::logging::create!(SessionBuilder, ptr);
		Self(ptr)
	}

	/// Returns the wrapped handle as a raw mutable pointer, without transferring ownership.
	#[inline]
	pub(crate) fn as_ptr(&self) -> *mut ort_sys::OrtSessionOptions {
		let Self(inner) = self;
		inner.as_ptr()
	}
}

impl Drop for SessionOptionsPointer {
	/// Releases the `OrtSessionOptions` handle, then reports the released address to `crate::logging::drop!`.
	fn drop(&mut self) {
		let raw = self.as_ptr();
		ortsys![unsafe ReleaseSessionOptions(raw)];
		crate::logging::drop!(SessionBuilder, raw);
	}
}