use alloc::{string::String, sync::Arc, vec::Vec};
use core::{
ffi::{CStr, c_char, c_int},
marker::PhantomData,
mem,
ptr::{self, NonNull}
};
use smallvec::SmallVec;
#[cfg(feature = "api-20")]
use crate::session::adapter::{Adapter, AdapterInner};
use crate::{
AsPointer,
error::Result,
logging::LogLevel,
ortsys,
session::Outlet,
util::{MiniMap, STACK_SESSION_OUTPUTS, with_cstr},
value::{DynValue, Value, ValueTypeMarker}
};
/// Selects which outputs a session run should produce, and which of them should use caller-provided
/// (preallocated) values.
///
/// The [`Default`] selector starts from every output the session declares; [`OutputSelector::no_default`]
/// starts from none. Names can then be added with [`OutputSelector::with`], removed from the defaults with
/// [`OutputSelector::without`], and given preallocated values with [`OutputSelector::preallocate`].
#[derive(Debug)]
pub struct OutputSelector {
	/// Whether the session's declared outputs form the starting selection.
	use_defaults: bool,
	/// Names removed from the session's default outputs; names in `allowlist` are not filtered by this.
	default_blocklist: Vec<String>,
	/// Extra output names, appended after the (non-blocked) default outputs.
	allowlist: Vec<String>,
	/// Caller-supplied values keyed by output name; looked up for each selected name when outputs are resolved.
	preallocated_outputs: MiniMap<String, Value>
}
impl Default for OutputSelector {
	/// Selects every output the session declares, with nothing blocked, nothing added and nothing
	/// preallocated.
	fn default() -> Self {
		let allowlist = Vec::new();
		let default_blocklist = Vec::new();
		Self {
			use_defaults: true,
			default_blocklist,
			allowlist,
			preallocated_outputs: MiniMap::new()
		}
	}
}
impl OutputSelector {
	/// Creates a selector that does **not** start from the session's declared outputs; only names added with
	/// [`OutputSelector::with`] will be requested.
	pub fn no_default() -> Self {
		let mut selector = Self::default();
		selector.use_defaults = false;
		selector
	}

	/// Requests the output `name` in addition to the (non-blocked) defaults.
	///
	/// Names added here are appended after the defaults and are not affected by [`OutputSelector::without`].
	pub fn with(mut self, name: impl Into<String>) -> Self {
		let name: String = name.into();
		self.allowlist.push(name);
		self
	}

	/// Removes the output `name` from the session's default outputs.
	///
	/// This only filters the defaults; it does not undo a matching [`OutputSelector::with`].
	pub fn without(mut self, name: impl Into<String>) -> Self {
		let name: String = name.into();
		self.default_blocklist.push(name);
		self
	}

	/// Registers `value` as the preallocated value for the output `name`.
	///
	/// The value is only used if `name` ends up among the selected outputs.
	pub fn preallocate<T: ValueTypeMarker>(mut self, name: impl Into<String>, value: Value<T>) -> Self {
		let key: String = name.into();
		let value = value.into_dyn();
		self.preallocated_outputs.insert(key, value);
		self
	}

	/// Resolves the final list of output names for a run, pairing each with a `DynValue::clone_of` of its
	/// preallocated value (if one was registered).
	///
	/// Order: the session's default outputs (when enabled, minus the blocklist), then every allowlisted name.
	pub(crate) fn resolve_outputs<'a, 's: 'a>(
		&'a self,
		outputs: &'s [Outlet]
	) -> (SmallVec<[&'a str; STACK_SESSION_OUTPUTS]>, SmallVec<[Option<DynValue>; STACK_SESSION_OUTPUTS]>) {
		// With defaults disabled, start from an empty slice so that only the allowlist contributes.
		let defaults: &'s [Outlet] = if self.use_defaults { outputs } else { &[] };
		let names = defaults
			.iter()
			.map(|outlet| outlet.name())
			.filter(|&name| self.default_blocklist.iter().all(|blocked| blocked != name))
			.chain(self.allowlist.iter().map(String::as_str));
		names
			.map(|name| {
				let preallocated = self.preallocated_outputs.get(name).map(DynValue::clone_of);
				(name, preallocated)
			})
			.unzip()
	}
}
/// Type-state marker for [`RunOptions`], recording whether an explicit [`OutputSelector`] has been attached.
pub trait SelectedOutputMarker {}

/// Marker for [`RunOptions`] that use the default output selection (the state returned by `RunOptions::new`).
// `Debug` is derived because `RunOptions<O>`'s `#[derive(Debug)]` only applies when `O: Debug`; without it,
// `RunOptions` could never be formatted.
#[derive(Debug, Clone, Copy)]
pub struct NoSelectedOutputs;
impl SelectedOutputMarker for NoSelectedOutputs {}

/// Marker for [`RunOptions`] that carry an explicit [`OutputSelector`] (the state returned by
/// `RunOptions::with_outputs`).
#[derive(Debug, Clone, Copy)]
pub struct HasSelectedOutputs;
impl SelectedOutputMarker for HasSelectedOutputs {}
/// Marker-independent state behind [`RunOptions`]: the raw ONNX Runtime handle plus the Rust-side settings
/// that accompany it.
#[derive(Debug)]
pub(crate) struct UntypedRunOptions {
	/// Owned `OrtRunOptions` handle; released in this type's `Drop` impl.
	pub(crate) ptr: NonNull<ort_sys::OrtRunOptions>,
	/// Which outputs runs using these options should produce.
	pub(crate) outputs: OutputSelector,
	/// LoRA adapters activated on these options. Holding the `Arc`s keeps each adapter alive for as long as the
	/// options that reference it.
	#[cfg(feature = "api-20")]
	adapters: Vec<Arc<AdapterInner>>
}
impl UntypedRunOptions {
	/// Sets ONNX Runtime's terminate flag on these run options (`RunOptionsSetTerminate`).
	///
	/// Because this lives on the non-generic inner type, anything holding the `Arc<UntypedRunOptions>` can call it
	/// without knowing the output-selection marker.
	pub fn terminate(&self) -> Result<()> {
		let handle = self.ptr.as_ptr();
		ortsys![unsafe RunOptionsSetTerminate(handle)?];
		Ok(())
	}
}
// SAFETY: `UntypedRunOptions` is the sole owner of the `OrtRunOptions` handle in `ptr` (it is released in the
// `Drop` impl below), so sending it to another thread moves ownership of the handle along with it.
// NOTE(review): this assumes ONNX Runtime allows an `OrtRunOptions` to be used and released on a thread other
// than the one that created it — confirm against the ORT C API docs.
unsafe impl Send for UntypedRunOptions {}
impl Drop for UntypedRunOptions {
	/// Releases the underlying `OrtRunOptions` handle and records the drop in the crate's logging.
	fn drop(&mut self) {
		// Fields (including any `adapters`) are dropped only after this body returns, i.e. after the run options
		// handle that references them has been released.
		ortsys![unsafe ReleaseRunOptions(self.ptr.as_ptr())];
		crate::logging::drop!(RunOptions, self.ptr);
	}
}
/// Options applied to individual session runs: run tag, log settings, termination, run config entries and,
/// once [`RunOptions::with_outputs`] is used, an explicit output selection.
///
/// `O` is a type-state marker. It is [`NoSelectedOutputs`] for options created by [`RunOptions::new`], and
/// [`HasSelectedOutputs`] after [`RunOptions::with_outputs`].
// NOTE(review): `#[derive(Debug)]` only implements `Debug` for `RunOptions<O>` when `O: Debug`, so each marker
// type must itself implement `Debug` for this impl to be usable.
#[derive(Debug)]
pub struct RunOptions<O: SelectedOutputMarker = NoSelectedOutputs> {
	/// Shared handle to the marker-independent state (raw ORT handle, output selection, adapters).
	pub(crate) inner: Arc<UntypedRunOptions>,
	/// Zero-sized record of the output-selection state `O`.
	_marker: PhantomData<O>
}
// SAFETY: `Arc<UntypedRunOptions>` is only `Send` when `UntypedRunOptions: Send + Sync`, and `UntypedRunOptions`
// is declared `Send` but not `Sync` (its `NonNull` field is `!Sync`), hence this manual impl.
// NOTE(review): soundness relies on ORT tolerating the shared handle being used from the receiving thread —
// confirm.
unsafe impl<O: SelectedOutputMarker> Send for RunOptions<O> {}
// SAFETY: shared references only expose the `&self` methods (`tag`, `terminate`, `unterminate`, `log_level`,
// `log_verbosity`), so concurrent use means concurrent ORT getter/terminate calls on one handle.
// NOTE(review): this assumes those ORT calls are safe to make concurrently. `Sync` is presumably limited to
// `NoSelectedOutputs` because such options always carry the default `OutputSelector` (no preallocated
// values) — confirm.
unsafe impl Sync for RunOptions<NoSelectedOutputs> {}
impl RunOptions {
	/// Creates run options with ONNX Runtime's default settings and no explicit output selection.
	///
	/// # Errors
	///
	/// Returns an error if the `CreateRunOptions` call (or its accompanying handle check) fails.
	pub fn new() -> Result<RunOptions<NoSelectedOutputs>> {
		let mut ptr: *mut ort_sys::OrtRunOptions = ptr::null_mut();
		// `nonNull(ptr)` validates the freshly created handle; from here on `ptr` is used as a `NonNull`.
		ortsys![unsafe CreateRunOptions(&mut ptr)?; nonNull(ptr)];
		crate::logging::create!(RunOptions, ptr);

		let untyped = UntypedRunOptions {
			ptr,
			outputs: OutputSelector::default(),
			#[cfg(feature = "api-20")]
			adapters: Vec::new()
		};
		Ok(RunOptions { inner: Arc::new(untyped), _marker: PhantomData })
	}
}
impl<O: SelectedOutputMarker> RunOptions<O> {
pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions<HasSelectedOutputs> {
let Some(inner) = Arc::get_mut(&mut self.inner) else {
panic!("Expected RunOptions to have exclusive access");
};
inner.outputs = outputs;
unsafe { mem::transmute(self) }
}
pub fn with_tag(mut self, tag: impl AsRef<str>) -> Result<Self> {
self.set_tag(tag).map(|_| self)
}
pub fn set_tag(&mut self, tag: impl AsRef<str>) -> Result<()> {
with_cstr(tag.as_ref().as_bytes(), &|tag| {
ortsys![unsafe RunOptionsSetRunTag(self.inner.ptr.as_ptr(), tag.as_ptr())?];
Ok(())
})
}
pub fn tag(&self) -> Result<&str> {
let mut tag_ptr: *const c_char = ptr::null();
ortsys![unsafe RunOptionsGetRunTag(self.inner.ptr.as_ptr(), &mut tag_ptr)?];
Ok(unsafe { CStr::from_ptr(tag_ptr) }.to_str()?)
}
pub fn terminate(&self) -> Result<()> {
self.inner.terminate()
}
pub fn unterminate(&self) -> Result<()> {
ortsys![unsafe RunOptionsUnsetTerminate(self.inner.ptr.as_ptr())?];
Ok(())
}
pub fn add_config_entry(&mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
with_cstr(key.as_ref().as_bytes(), &|key| {
with_cstr(value.as_ref().as_bytes(), &|value| {
ortsys![unsafe AddRunConfigEntry(self.inner.ptr.as_ptr(), key.as_ptr(), value.as_ptr())?];
Ok(())
})
})
}
#[cfg(feature = "api-20")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-20")))]
pub fn add_adapter(&mut self, adapter: &Adapter) -> Result<()> {
let Some(inner) = Arc::get_mut(&mut self.inner) else {
panic!("Expected RunOptions to have exclusive access");
};
ortsys![unsafe RunOptionsAddActiveLoraAdapter(inner.ptr.as_ptr(), adapter.ptr())?];
inner.adapters.push(Arc::clone(&adapter.inner));
Ok(())
}
pub fn set_log_level(&mut self, level: LogLevel) -> Result<()> {
ortsys![unsafe RunOptionsSetRunLogSeverityLevel(self.ptr_mut(), ort_sys::OrtLoggingLevel::from(level) as _)?];
Ok(())
}
pub fn log_level(&self) -> Result<LogLevel> {
let mut log_level = ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE;
ortsys![unsafe RunOptionsGetRunLogSeverityLevel(self.ptr(), &mut log_level as *mut ort_sys::OrtLoggingLevel as *mut _)?];
Ok(LogLevel::from(log_level))
}
pub fn set_log_verbosity(&mut self, verbosity: c_int) -> Result<()> {
ortsys![unsafe RunOptionsSetRunLogVerbosityLevel(self.ptr_mut(), verbosity)?];
Ok(())
}
pub fn log_verbosity(&self) -> Result<i32> {
let mut verbosity = 0;
ortsys![unsafe RunOptionsGetRunLogVerbosityLevel(self.ptr(), &mut verbosity)?];
Ok(verbosity)
}
pub fn disable_device_sync(&mut self) -> Result<()> {
self.add_config_entry("disable_synchronize_execution_providers", "1")
}
}
impl<O: SelectedOutputMarker> AsPointer for RunOptions<O> {
	type Sys = ort_sys::OrtRunOptions;

	/// Returns the raw `OrtRunOptions` handle held by the inner options. The handle is released only when the
	/// last `Arc` to the inner options is dropped.
	fn ptr(&self) -> *const Self::Sys {
		let handle = self.inner.ptr;
		handle.as_ptr().cast_const()
	}
}