use alloc::{
borrow::Cow,
sync::{Arc, Weak},
vec::Vec
};
use core::{
any::Any,
ptr::{self, NonNull}
};
use smallvec::SmallVec;
use crate::{
AsPointer, Error,
environment::{self, Environment},
error::Result,
logging::LoggerFunction,
memory::MemoryInfo,
operator::OperatorDomain,
ortsys,
util::with_cstr,
value::DynValue
};
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
mod editable;
mod impl_commit;
mod impl_config_keys;
mod impl_options;
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub use self::editable::*;
pub use self::impl_options::*;
/// Result of a fallible, consuming [`SessionBuilder`] configuration step.
///
/// The error type is parameterized over `SessionBuilder`. For example, `with_config_entry` attaches the
/// builder to its error via `with_recover`, presumably so callers can get the builder back after a failure.
pub type BuilderResult = Result<SessionBuilder, Error<SessionBuilder>>;
/// Builder for an ONNX Runtime session.
///
/// It owns the native session options (`OrtSessionOptions`) together with the Rust-side values that are stored
/// next to them.
///
/// Fields are dropped in declaration order. The native options (`session_options_ptr`) are therefore released
/// before the other resources held here. NOTE(review): presumably this ordering is intentional — confirm before
/// reordering fields.
pub struct SessionBuilder {
// Owned session options handle. It sits behind an `Arc` so that `canceler` (feature `api-22`) can hand out
// a `Weak` reference to it.
session_options_ptr: Arc<SessionOptionsPointer>,
// NOTE(review): in this view the following fields are only initialized to empty/`None`/`false` in `new`
// and copied in `clone`. They are presumably filled in by the option setters elsewhere in this module and
// kept alive here for the native options/session that reference them — TODO confirm.
memory_info: Option<Arc<MemoryInfo>>,
operator_domains: SmallVec<[Arc<OperatorDomain>; 1]>,
initializers: Vec<Arc<DynValue>>,
external_initializer_buffers: Vec<Cow<'static, [u8]>>,
prepacked_weights: Option<PrepackedWeights>,
thread_manager: Option<Arc<dyn Any>>,
logger: Option<Arc<LoggerFunction>>,
no_global_thread_pool: bool,
no_env_eps: bool,
// Environment captured by `new` from `environment::current()`.
pub(crate) environment: Arc<Environment>
}
/// Clones a builder.
///
/// ONNX Runtime is asked to copy the native session options (`CloneSessionOptions`), so the clone owns an
/// independent options handle. Every other field is cloned with its own `Clone` impl. Most of them are `Arc`
/// handles and are shared with the original. `external_initializer_buffers`, however, holds `Cow`s, so any
/// `Owned` buffers in it are deep-copied.
///
/// # Panics
///
/// Panics with "error cloning session options" if `CloneSessionOptions` reports an error, because
/// `Clone::clone` has no way to return one.
impl Clone for SessionBuilder {
fn clone(&self) -> Self {
let mut session_options_ptr = ptr::null_mut();
// The `nonNull(...)` clause rebinds `session_options_ptr`. Afterwards it is passed to
// `SessionOptionsPointer::new`, which takes a `NonNull`.
ortsys![
unsafe CloneSessionOptions(self.ptr(), ptr::addr_of_mut!(session_options_ptr))
.expect("error cloning session options");
nonNull(session_options_ptr)
];
Self {
session_options_ptr: Arc::new(SessionOptionsPointer::new(session_options_ptr)),
memory_info: self.memory_info.clone(),
operator_domains: self.operator_domains.clone(),
initializers: self.initializers.clone(),
external_initializer_buffers: self.external_initializer_buffers.clone(),
prepacked_weights: self.prepacked_weights.clone(),
thread_manager: self.thread_manager.clone(),
logger: self.logger.clone(),
no_global_thread_pool: self.no_global_thread_pool,
no_env_eps: self.no_env_eps,
environment: self.environment.clone()
}
}
}
impl SessionBuilder {
	/// Creates a builder with freshly created session options, bound to the current environment.
	///
	/// With the `api-22` feature, the `AutoDevicePolicy::MaxEfficiency` execution-provider selection policy is
	/// also requested. The outcome of that request is deliberately discarded.
	///
	/// # Errors
	///
	/// Fails if `environment::current()` fails or if ONNX Runtime cannot create the session options.
	pub fn new() -> Result<Self> {
		let environment = environment::current()?;

		let mut session_options_ptr: *mut ort_sys::OrtSessionOptions = ptr::null_mut();
		ortsys![unsafe CreateSessionOptions(&mut session_options_ptr)?; nonNull(session_options_ptr)];

		// Best-effort: an error from this call is ignored rather than failing construction.
		#[cfg(feature = "api-22")]
		let _ = ortsys![@ort: unsafe SessionOptionsSetEpSelectionPolicy(session_options_ptr.as_ptr(), AutoDevicePolicy::MaxEfficiency.into()) as Result];

		let options = Arc::new(SessionOptionsPointer::new(session_options_ptr));
		Ok(Self {
			session_options_ptr: options,
			memory_info: None,
			operator_domains: SmallVec::new(),
			initializers: Vec::new(),
			external_initializer_buffers: Vec::new(),
			prepacked_weights: None,
			thread_manager: None,
			logger: None,
			no_global_thread_pool: false,
			no_env_eps: false,
			environment
		})
	}

	/// Adds a `key = value` entry to the session configuration through `AddSessionConfigEntry`.
	///
	/// Both strings are turned into C strings with `with_cstr` before they are handed to ONNX Runtime.
	///
	/// # Errors
	///
	/// Propagates any error from the C-string conversion or from ONNX Runtime.
	#[inline]
	pub(crate) fn add_config_entry(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
		let options = self.ptr_mut();
		let key_bytes = key.as_ref().as_bytes();
		with_cstr(key_bytes, &|c_key| {
			with_cstr(value.as_ref().as_bytes(), &|c_value| {
				ortsys![unsafe AddSessionConfigEntry(options, c_key.as_ptr(), c_value.as_ptr())?];
				Ok(())
			})
		})
	}

	/// Returns a handle that can request cancellation of a model load using these session options.
	///
	/// The handle only holds a weak reference, so it does not keep the options alive.
	#[cfg(feature = "api-22")]
	#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
	pub fn canceler(&self) -> LoadCanceler {
		let weak_options = Arc::downgrade(&self.session_options_ptr);
		LoadCanceler(weak_options)
	}

	/// Sets a session configuration entry and returns the builder for chaining.
	///
	/// # Errors
	///
	/// If the entry cannot be added, the builder is attached to the returned error via `with_recover`.
	pub fn with_config_entry(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> BuilderResult {
		if let Err(e) = self.add_config_entry(key.as_ref(), value.as_ref()) {
			return Err(e.with_recover(self));
		}
		Ok(self)
	}
}
impl AsPointer for SessionBuilder {
	type Sys = ort_sys::OrtSessionOptions;

	/// Returns the raw `OrtSessionOptions` handle owned by this builder, viewed as a const pointer.
	fn ptr(&self) -> *const Self::Sys {
		let raw = self.session_options_ptr.as_ptr();
		raw as *const Self::Sys
	}
}
/// Handle, obtained from `SessionBuilder::canceler`, that can set the load-cancellation flag
/// (`SessionOptionsSetLoadCancellationFlag`) on a builder's session options.
///
/// It holds only a `Weak` reference to the options. Once every strong reference to them has been dropped,
/// `cancel` does nothing and returns `Ok(())`.
#[derive(Debug, Clone)]
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub struct LoadCanceler(Weak<SessionOptionsPointer>);
// SAFETY (review note): `Weak<SessionOptionsPointer>` is neither `Send` nor `Sync` on its own, because the
// wrapped `NonNull` is neither. These impls assert that the handle may nevertheless be moved to and shared
// between threads. `cancel` only upgrades the weak reference and calls
// `SessionOptionsSetLoadCancellationFlag`. If the upgraded `Arc` happens to be the last strong reference,
// `ReleaseSessionOptions` (in `SessionOptionsPointer::drop`) also runs on the cancelling thread.
// Presumably ONNX Runtime allows both calls from any thread — TODO confirm against the ORT C API docs.
#[cfg(feature = "api-22")]
unsafe impl Send for LoadCanceler {}
#[cfg(feature = "api-22")]
unsafe impl Sync for LoadCanceler {}
#[cfg(feature = "api-22")]
impl LoadCanceler {
	/// Sets the load-cancellation flag on the associated session options.
	///
	/// If the options have already been dropped, there is nothing to cancel and `Ok(())` is returned.
	///
	/// # Errors
	///
	/// Propagates any error reported by `SessionOptionsSetLoadCancellationFlag`.
	// The enclosing impl is already gated on `api-22`, so the method needs no `cfg` of its own.
	#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
	pub fn cancel(&self) -> Result<()> {
		match self.0.upgrade() {
			Some(options) => {
				ortsys![unsafe SessionOptionsSetLoadCancellationFlag(options.as_ptr(), true)?];
				Ok(())
			}
			None => Ok(())
		}
	}
}
/// Owning wrapper around a non-null `OrtSessionOptions` handle.
///
/// The handle is released with `ReleaseSessionOptions` when the wrapper is dropped.
#[derive(Debug)]
#[repr(transparent)]
pub(crate) struct SessionOptionsPointer(NonNull<ort_sys::OrtSessionOptions>);
impl SessionOptionsPointer {
#[inline]
pub(crate) fn new(ptr: NonNull<ort_sys::OrtSessionOptions>) -> Self {
crate::logging::create!(SessionBuilder, ptr);
Self(ptr)
}
#[inline]
pub(crate) fn as_ptr(&self) -> *mut ort_sys::OrtSessionOptions {
self.0.as_ptr()
}
}
/// Releases the owned `OrtSessionOptions` handle via `ReleaseSessionOptions`, then logs the drop.
impl Drop for SessionOptionsPointer {
fn drop(&mut self) {
ortsys![unsafe ReleaseSessionOptions(self.0.as_ptr())];
// The handle has already been released here. Presumably `drop!` only records the pointer value and
// never dereferences it — TODO confirm against `crate::logging::drop!`.
crate::logging::drop!(SessionBuilder, self.0.as_ptr());
}
}