#![allow(unused)]
use std::{
collections::HashMap,
ffi::{c_void, CString},
os::raw::c_char,
ptr
};
use crate::{
error::status_to_result,
ortsys,
sys::{self, size_t, OrtArenaCfg},
OrtApiError, OrtError, OrtResult
};
// Registration entry points for execution providers that are appended through a free C function
// taking the session options. These declarations are only compiled when the `load-dynamic` feature
// is disabled (and the target is not x86); each one is additionally gated on its provider's feature.
// With `load-dynamic` enabled, same-named symbols are looked up at runtime by `get_ep_register!`.
// NOTE(review): the reason x86 targets are excluded is not visible in this file — confirm.
#[cfg(all(not(feature = "load-dynamic"), not(target_arch = "x86")))]
extern "C" {
// ACL provider; `use_arena` is passed as a C int.
#[cfg(feature = "acl")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
// oneDNN provider; `use_arena` is passed as a C int.
#[cfg(feature = "onednn")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_Dnnl(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr;
// CoreML provider; `flags` is a bit set.
#[cfg(feature = "coreml")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_CoreML(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr;
// DirectML provider; selects a device by index.
#[cfg(feature = "directml")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr;
// NNAPI provider; `flags` is a bit set.
#[cfg(feature = "nnapi")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_Nnapi(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr;
// TVM provider; `opt_str` is a NUL-terminated string of options.
#[cfg(feature = "tvm")]
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut sys::OrtSessionOptions, opt_str: *const std::os::raw::c_char)
-> sys::OrtStatusPtr;
}
/// Options for the CPU execution provider.
#[derive(Clone, Debug, Default)]
pub struct CPUExecutionProviderOptions {
    /// `true` enables the CPU memory arena on the session options; `false` disables it.
    pub use_arena: bool,
}
/// How a device memory arena grows when it needs more space.
#[derive(Clone, Debug, Default)]
pub enum ArenaExtendStrategy {
    /// The default strategy.
    #[default]
    NextPowerOfTwo,
    /// The alternative strategy.
    SameAsRequested,
}
/// Which cuDNN convolution algorithm search mode the CUDA provider is asked to use.
#[derive(Clone, Debug, Default)]
pub enum CUDAExecutionProviderCuDNNConvAlgoSearch {
    /// The mode returned by `Default::default()`.
    #[default]
    Exhaustive,
    Heuristic,
    /// Note: despite its name, this variant is *not* what `Default::default()` returns.
    Default,
}
/// Options for the CUDA execution provider.
#[derive(Clone, Debug)]
pub struct CUDAExecutionProviderOptions {
    /// Index of the CUDA device to use.
    pub device_id: u32,
    /// Upper bound on GPU memory; defaults to `usize::MAX`.
    /// NOTE(review): presumably expressed in bytes — confirm against the ONNX Runtime documentation.
    pub gpu_mem_limit: usize,
    /// Growth strategy of the device memory arena.
    pub arena_extend_strategy: ArenaExtendStrategy,
    /// cuDNN convolution algorithm search mode.
    pub cudnn_conv_algo_search: CUDAExecutionProviderCuDNNConvAlgoSearch,
    pub do_copy_in_default_stream: bool,
    pub cudnn_conv_use_max_workspace: bool,
    pub cudnn_conv1d_pad_to_nc1d: bool,
    pub enable_cuda_graph: bool,
    pub enable_skip_layer_norm_strict_mode: bool,
}

impl Default for CUDAExecutionProviderOptions {
    /// Device 0, no memory limit, power-of-two arena growth, exhaustive cuDNN search; copying in
    /// the default stream and the maximum cuDNN workspace are on, everything else is off.
    fn default() -> Self {
        let arena_extend_strategy = ArenaExtendStrategy::NextPowerOfTwo;
        let cudnn_conv_algo_search = CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive;
        Self {
            device_id: 0,
            gpu_mem_limit: usize::MAX,
            arena_extend_strategy,
            cudnn_conv_algo_search,
            do_copy_in_default_stream: true,
            cudnn_conv_use_max_workspace: true,
            cudnn_conv1d_pad_to_nc1d: false,
            enable_cuda_graph: false,
            enable_skip_layer_norm_strict_mode: false,
        }
    }
}
/// Options for the TensorRT execution provider.
///
/// String fields left empty are treated as "not set".
#[derive(Clone, Debug)]
pub struct TensorRTExecutionProviderOptions {
    /// Index of the device to use.
    pub device_id: u32,
    /// Defaults to `1 << 30`.
    pub max_workspace_size: u32,
    pub max_partition_iterations: u32,
    pub min_subgraph_size: u32,
    pub fp16_enable: bool,
    pub int8_enable: bool,
    pub int8_calibration_table_name: String,
    pub int8_use_native_calibration_table: bool,
    pub dla_enable: bool,
    pub dla_core: u32,
    pub engine_cache_enable: bool,
    pub engine_cache_path: String,
    pub dump_subgraphs: bool,
    pub force_sequential_engine_build: bool,
    pub enable_context_memory_sharing: bool,
    pub layer_norm_fp32_fallback: bool,
    pub timing_cache_enable: bool,
    pub force_timing_cache: bool,
    pub detailed_build_log: bool,
    pub enable_build_heuristics: bool,
    pub enable_sparsity: bool,
    /// Defaults to `3`.
    pub builder_optimization_level: u8,
    /// Defaults to `-1`.
    pub auxiliary_streams: i8,
    pub tactic_sources: String,
    pub extra_plugin_lib_paths: String,
    pub profile_min_shapes: String,
    pub profile_max_shapes: String,
    pub profile_opt_shapes: String,
}

impl Default for TensorRTExecutionProviderOptions {
    /// All switches off and all strings empty; the numeric defaults are spelled out below.
    fn default() -> Self {
        Self {
            // Numeric settings.
            device_id: 0,
            max_workspace_size: 1 << 30,
            max_partition_iterations: 1000,
            min_subgraph_size: 1,
            dla_core: 0,
            builder_optimization_level: 3,
            auxiliary_streams: -1,
            // Feature switches.
            fp16_enable: false,
            int8_enable: false,
            int8_use_native_calibration_table: false,
            dla_enable: false,
            engine_cache_enable: false,
            dump_subgraphs: false,
            force_sequential_engine_build: false,
            enable_context_memory_sharing: false,
            layer_norm_fp32_fallback: false,
            timing_cache_enable: false,
            force_timing_cache: false,
            detailed_build_log: false,
            enable_build_heuristics: false,
            enable_sparsity: false,
            // String settings.
            int8_calibration_table_name: String::new(),
            engine_cache_path: String::new(),
            tactic_sources: String::new(),
            extra_plugin_lib_paths: String::new(),
            profile_min_shapes: String::new(),
            profile_max_shapes: String::new(),
            profile_opt_shapes: String::new(),
        }
    }
}
#[derive(Debug, Clone)]
pub struct OpenVINOExecutionProviderOptions {
pub device_type: Option<String>,
pub device_id: Option<String>,
pub num_threads: size_t,
pub cache_dir: Option<String>,
pub context: *mut c_void,
pub enable_opencl_throttling: bool,
pub enable_dynamic_shapes: bool,
pub enable_vpu_fast_compile: bool
}
impl Default for OpenVINOExecutionProviderOptions {
fn default() -> Self {
Self {
device_type: None,
device_id: None,
num_threads: 8,
cache_dir: None,
context: std::ptr::null_mut(),
enable_opencl_throttling: false,
enable_dynamic_shapes: false,
enable_vpu_fast_compile: false
}
}
}
/// Options for the oneDNN execution provider.
#[derive(Clone, Debug, Default)]
pub struct OneDNNExecutionProviderOptions {
    /// Passed to the provider's registration function as `1` (true) or `0` (false).
    pub use_arena: bool,
}
/// Options for the ACL execution provider.
#[derive(Clone, Debug, Default)]
pub struct ACLExecutionProviderOptions {
    /// Passed to the provider's registration function as `1` (true) or `0` (false).
    pub use_arena: bool,
}
/// Options for the CoreML execution provider; each field maps to one bit of the flag word handed
/// to the provider's registration function.
#[derive(Clone, Debug, Default)]
pub struct CoreMLExecutionProviderOptions {
    /// Sets flag bit `0x001`.
    pub use_cpu_only: bool,
    /// Sets flag bit `0x002`.
    pub enable_on_subgraph: bool,
    /// Sets flag bit `0x004`.
    pub only_enable_device_with_ane: bool,
}
/// Options for the DirectML execution provider.
#[derive(Clone, Debug, Default)]
pub struct DirectMLExecutionProviderOptions {
    /// Index of the device to use; defaults to `0`.
    pub device_id: u32,
}
#[derive(Debug, Clone, Default)]
pub struct ROCmExecutionProviderOptions {
pub device_id: i32,
pub miopen_conv_exhaustive_search: bool,
pub gpu_mem_limit: size_t,
pub arena_extend_strategy: ArenaExtendStrategy,
pub do_copy_in_default_stream: bool,
pub user_compute_stream: Option<*mut c_void>,
pub default_memory_arena_cfg: Option<*mut sys::OrtArenaCfg>,
pub tunable_op_enable: bool,
pub tunable_op_tuning_enable: bool,
pub tunable_op_max_tuning_duration_ms: i32
}
/// Options for the NNAPI execution provider; each field maps to one bit of the flag word handed
/// to the provider's registration function.
#[derive(Clone, Debug, Default)]
pub struct NNAPIExecutionProviderOptions {
    /// Sets flag bit `0x001`.
    pub use_fp16: bool,
    /// Sets flag bit `0x002`.
    pub use_nchw: bool,
    /// Sets flag bit `0x004`.
    pub disable_cpu: bool,
    /// Sets flag bit `0x008`.
    pub cpu_only: bool,
}
/// HTP performance mode requested from the QNN execution provider.
#[derive(Clone, Debug)]
pub enum QNNExecutionHTPPerformanceMode {
    Default,
    Burst,
    Balanced,
    HighPerformance,
    HighPowerSaver,
    LowPowerSaver,
    LowBalanced,
    PowerSaver,
    SustainedHighPerformance,
}

impl QNNExecutionHTPPerformanceMode {
    /// Returns the snake_case option value that represents this mode.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Default => "default",
            Self::Burst => "burst",
            Self::Balanced => "balanced",
            Self::HighPerformance => "high_performance",
            Self::HighPowerSaver => "high_power_saver",
            Self::LowPowerSaver => "low_power_saver",
            Self::LowBalanced => "low_balanced",
            Self::PowerSaver => "power_saver",
            Self::SustainedHighPerformance => "sustained_high_performance",
        }
    }
}
/// Options for the QNN execution provider.
#[derive(Clone, Debug)]
pub struct QNNExecutionProviderOptions {
    /// Backend library to load; defaults to `libQnnHtp.so`.
    pub backend_path: String,
    pub qnn_context_cache_enable: bool,
    /// Defaults to `model_file.onnx.bin`.
    pub qnn_context_cache_path: Option<String>,
    /// Defaults to `off`.
    pub profiling_level: Option<String>,
    /// Defaults to `10`.
    pub rpc_control_latency: Option<u32>,
    pub htp_performance_mode: Option<QNNExecutionHTPPerformanceMode>,
}

impl Default for QNNExecutionProviderOptions {
    /// Every optional field is populated with an explicit default value.
    fn default() -> Self {
        let backend_path = "libQnnHtp.so".to_owned();
        let qnn_context_cache_path = Some("model_file.onnx.bin".to_owned());
        let profiling_level = Some("off".to_owned());
        Self {
            backend_path,
            qnn_context_cache_enable: false,
            qnn_context_cache_path,
            profiling_level,
            rpc_control_latency: Some(10),
            htp_performance_mode: Some(QNNExecutionHTPPerformanceMode::Default),
        }
    }
}
/// Executor the TVM execution provider should use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TVMExecutorType {
    GraphExecutor,
    VirtualMachine,
}
/// Kind of TVM tuning data being referred to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TVMTuningType {
    AutoTVM,
    Ansor,
}
/// Options for the TVM execution provider.
///
/// Every field is optional; the derived `Default` leaves all of them as `None`.
#[derive(Clone, Debug, Default)]
pub struct TVMExecutionProviderOptions {
    pub executor: Option<TVMExecutorType>,
    pub so_folder: Option<String>,
    pub check_hash: Option<bool>,
    pub hash_file_path: Option<String>,
    pub target: Option<String>,
    pub target_host: Option<String>,
    pub opt_level: Option<usize>,
    pub freeze_weights: Option<bool>,
    pub to_nhwc: Option<bool>,
    pub tuning_type: Option<TVMTuningType>,
    pub tuning_file_path: Option<String>,
    pub input_names: Option<String>,
    pub input_shapes: Option<String>,
}
/// Precision mode requested from the CANN execution provider.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CANNExecutionProviderPrecisionMode {
    ForceFP32,
    ForceFP16,
    AllowFP32ToFP16,
    MustKeepOrigin,
    AllowMixedPrecision,
}
/// Operator implementation mode requested from the CANN execution provider.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CANNExecutionProviderImplementationMode {
    HighPrecision,
    HighPerformance,
}
/// Options for the CANN execution provider.
#[derive(Clone, Debug)]
pub struct CANNExecutionProviderOptions {
    /// Index of the device to use.
    pub device_id: u32,
    /// Upper bound on NPU memory; defaults to `usize::MAX`.
    pub npu_mem_limit: usize,
    /// Growth strategy of the device memory arena.
    pub arena_extend_strategy: ArenaExtendStrategy,
    pub enable_cann_graph: bool,
    pub dump_graphs: bool,
    pub precision_mode: CANNExecutionProviderPrecisionMode,
    pub op_select_impl_mode: CANNExecutionProviderImplementationMode,
    /// Left empty by default, in which case it is not sent to ONNX Runtime.
    pub optypelist_for_impl_mode: String,
}

impl Default for CANNExecutionProviderOptions {
    /// Device 0, no memory limit, CANN graph enabled, FP16 forced, high-performance operators.
    fn default() -> Self {
        let precision_mode = CANNExecutionProviderPrecisionMode::ForceFP16;
        let op_select_impl_mode = CANNExecutionProviderImplementationMode::HighPerformance;
        Self {
            device_id: 0,
            npu_mem_limit: usize::MAX,
            arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo,
            enable_cann_graph: true,
            dump_graphs: false,
            precision_mode,
            op_select_impl_mode,
            optypelist_for_impl_mode: String::new(),
        }
    }
}
// Resolves an execution-provider registration function by name.
//
// Invoked with a C function signature, e.g. `get_ep_register!(Name(arg: Type, ...) -> Ret)`.
// With the `load-dynamic` feature enabled, this defines a local variable named after the symbol
// that holds the function looked up in the library stored in `G_ORT_LIB`. Without the feature the
// `let` statement is compiled out, so the macro expands to nothing and the name is resolved from
// whatever is already in scope.
//
// Must be used inside a function returning a `Result` whose error type is `OrtError`: a failed
// lookup makes the *enclosing function* return `OrtError::DlLoad`.
macro_rules! get_ep_register {
($symbol:ident($($id:ident: $type:ty),*) -> $rt:ty) => {
#[cfg(feature = "load-dynamic")]
#[allow(non_snake_case)]
let $symbol = unsafe {
use crate::G_ORT_LIB;
// Panics if the lock guarding the library handle has been poisoned.
let dylib = *G_ORT_LIB
.lock()
.expect("failed to acquire ONNX Runtime dylib lock; another thread panicked?")
.get_mut();
// Look the function up under exactly the identifier given to the macro.
let symbol: Result<
libloading::Symbol<unsafe extern "C" fn($($id: $type),*) -> $rt>,
libloading::Error
> = (*dylib).get(stringify!($symbol).as_bytes());
match symbol {
Ok(symbol) => symbol,
Err(e) => {
// Early-returns from the function that invoked the macro.
return Err(OrtError::DlLoad { symbol: stringify!($symbol), error: e.to_string() });
}
}
};
};
}
/// An execution provider together with the options used to register it on a session.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ExecutionProvider {
/// Identified as `CPUExecutionProvider`.
CPU(CPUExecutionProviderOptions),
/// Identified as `CUDAExecutionProvider`.
CUDA(CUDAExecutionProviderOptions),
/// Identified as `TensorrtExecutionProvider`.
TensorRT(TensorRTExecutionProviderOptions),
/// Identified as `OpenVINOExecutionProvider`.
OpenVINO(OpenVINOExecutionProviderOptions),
/// Identified as `AclExecutionProvider`.
ACL(ACLExecutionProviderOptions),
/// Identified as `DnnlExecutionProvider`.
OneDNN(OneDNNExecutionProviderOptions),
/// Identified as `CoreMLExecutionProvider`.
CoreML(CoreMLExecutionProviderOptions),
/// Identified as `DmlExecutionProvider`.
DirectML(DirectMLExecutionProviderOptions),
/// Identified as `ROCmExecutionProvider`.
ROCm(ROCmExecutionProviderOptions),
/// Identified as `NnapiExecutionProvider`.
NNAPI(NNAPIExecutionProviderOptions),
/// Identified as `QNNExecutionProvider`.
QNN(QNNExecutionProviderOptions),
/// Identified as `TvmExecutionProvider`.
TVM(TVMExecutionProviderOptions),
/// Identified as `CANNExecutionProvider`.
CANN(CANNExecutionProviderOptions)
}
// Builds parallel C-string key/value arrays from `name = expression` pairs.
//
// Each expression is rendered with `to_string()`; pairs whose rendered value is empty are left
// out. The result is the tuple `(key_ptrs, value_ptrs, len, keys, values)`: the pointer vectors
// point into the `CString`s held by `keys` and `values`, so those two vectors must be kept alive
// for as long as the pointers are in use.
//
// Panics if a rendered value contains an interior NUL byte.
macro_rules! map_keys {
    ($($fn_name:ident = $ex:expr),*) => {
        {
            let mut pairs = Vec::<(CString, CString)>::new();
            $(
                let rendered = ($ex).to_string();
                if !rendered.is_empty() {
                    pairs.push((CString::new(stringify!($fn_name)).unwrap(), CString::new(rendered).unwrap()));
                }
            )*
            // Splitting the pairs guarantees that both vectors have the same length.
            let (keys, values): (Vec<CString>, Vec<CString>) = pairs.into_iter().unzip();
            let key_ptrs: Vec<*const c_char> = keys.iter().map(|key| key.as_ptr()).collect();
            let value_ptrs: Vec<*const c_char> = values.iter().map(|value| value.as_ptr()).collect();
            let count = keys.len();
            (key_ptrs, value_ptrs, count, keys, values)
        }
    };
}
/// Converts a boolean into the `1`/`0` integer form expected by the C API.
#[inline]
fn bool_as_int(x: bool) -> i32 {
    i32::from(x)
}
impl ExecutionProvider {
    /// Returns the name under which ONNX Runtime lists this execution provider.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::CPU(_) => "CPUExecutionProvider",
            Self::CUDA(_) => "CUDAExecutionProvider",
            Self::TensorRT(_) => "TensorrtExecutionProvider",
            Self::OpenVINO(_) => "OpenVINOExecutionProvider",
            Self::ACL(_) => "AclExecutionProvider",
            Self::OneDNN(_) => "DnnlExecutionProvider",
            Self::CoreML(_) => "CoreMLExecutionProvider",
            Self::DirectML(_) => "DmlExecutionProvider",
            Self::ROCm(_) => "ROCmExecutionProvider",
            Self::NNAPI(_) => "NnapiExecutionProvider",
            Self::QNN(_) => "QNNExecutionProvider",
            Self::TVM(_) => "TvmExecutionProvider",
            Self::CANN(_) => "CANNExecutionProvider"
        }
    }

    /// Returns `true` if ONNX Runtime reports a provider whose name equals [`Self::as_str`].
    ///
    /// Returns `false` when the provider is not listed or when the list cannot be queried.
    pub fn is_available(&self) -> bool {
        let mut providers: *mut *mut c_char = ptr::null_mut();
        let mut num_providers = 0;
        if status_to_result(ortsys![unsafe GetAvailableProviders(&mut providers, &mut num_providers)]).is_err() {
            return false;
        }

        let wanted = self.as_str();
        let found = (0..num_providers).any(|i| {
            // SAFETY: the call above succeeded, so `providers` is expected to point to
            // `num_providers` NUL-terminated strings that stay valid until they are released below.
            let name = unsafe { std::ffi::CStr::from_ptr(*providers.offset(i as isize)) };
            name.to_string_lossy() == wanted
        });

        // Release the list exactly once, whether or not a match was found.
        let _ = ortsys![unsafe ReleaseAvailableProviders(providers, num_providers)];
        found
    }

    /// Registers this execution provider, with its options, on `session_options`.
    ///
    /// # Errors
    ///
    /// - [`OrtError::ExecutionProvider`] if an ONNX Runtime call reports a failure.
    /// - [`OrtError::DlLoad`] if a registration symbol cannot be found (`load-dynamic` builds).
    /// - [`OrtError::ExecutionProviderNotRegistered`] if support for this provider was not
    ///   compiled in.
    ///
    /// # Panics
    ///
    /// Panics if a string option contains an interior NUL byte, since it cannot be represented
    /// as a C string.
    pub(crate) fn apply(&self, session_options: *mut sys::OrtSessionOptions) -> OrtResult<()> {
        match self {
            Self::CPU(options) => {
                if options.use_arena {
                    status_to_result(ortsys![unsafe EnableCpuMemArena(session_options)]).map_err(OrtError::ExecutionProvider)?;
                } else {
                    status_to_result(ortsys![unsafe DisableCpuMemArena(session_options)]).map_err(OrtError::ExecutionProvider)?;
                }
            }
            #[cfg(any(feature = "load-dynamic", feature = "cuda"))]
            Self::CUDA(options) => {
                let mut cuda_options: *mut sys::OrtCUDAProviderOptionsV2 = ptr::null_mut();
                status_to_result(ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)]).map_err(OrtError::ExecutionProvider)?;
                // `keys`/`values` own the C strings that `key_ptrs`/`value_ptrs` point into.
                let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
                    device_id = options.device_id,
                    gpu_mem_limit = options.gpu_mem_limit,
                    arena_extend_strategy = match options.arena_extend_strategy {
                        ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
                        ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
                    },
                    cudnn_conv_algo_search = match options.cudnn_conv_algo_search {
                        CUDAExecutionProviderCuDNNConvAlgoSearch::Exhaustive => "EXHAUSTIVE",
                        CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
                        CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
                    },
                    do_copy_in_default_stream = bool_as_int(options.do_copy_in_default_stream),
                    cudnn_conv_use_max_workspace = bool_as_int(options.cudnn_conv_use_max_workspace),
                    cudnn_conv1d_pad_to_nc1d = bool_as_int(options.cudnn_conv1d_pad_to_nc1d),
                    enable_cuda_graph = bool_as_int(options.enable_cuda_graph),
                    enable_skip_layer_norm_strict_mode = bool_as_int(options.enable_skip_layer_norm_strict_mode)
                };
                if let Err(e) = status_to_result(ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
                    .map_err(OrtError::ExecutionProvider)
                {
                    ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
                    std::mem::drop((keys, values));
                    return Err(e);
                }
                // Release the options object before inspecting the status so it is freed on both paths.
                let status = ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options)];
                ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
                std::mem::drop((keys, values));
                status_to_result(status).map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "tensorrt"))]
            Self::TensorRT(options) => {
                let mut trt_options: *mut sys::OrtTensorRTProviderOptionsV2 = ptr::null_mut();
                status_to_result(ortsys![unsafe CreateTensorRTProviderOptions(&mut trt_options)]).map_err(OrtError::ExecutionProvider)?;
                // `keys`/`values` own the C strings that `key_ptrs`/`value_ptrs` point into.
                let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
                    device_id = options.device_id,
                    trt_max_workspace_size = options.max_workspace_size,
                    trt_max_partition_iterations = options.max_partition_iterations,
                    trt_min_subgraph_size = options.min_subgraph_size,
                    trt_fp16_enable = bool_as_int(options.fp16_enable),
                    trt_int8_enable = bool_as_int(options.int8_enable),
                    trt_int8_calibration_table_name = options.int8_calibration_table_name,
                    trt_int8_use_native_calibration_table = bool_as_int(options.int8_use_native_calibration_table),
                    trt_dla_enable = bool_as_int(options.dla_enable),
                    trt_dla_core = options.dla_core,
                    trt_engine_cache_enable = bool_as_int(options.engine_cache_enable),
                    trt_engine_cache_path = options.engine_cache_path,
                    trt_dump_subgraphs = bool_as_int(options.dump_subgraphs),
                    trt_force_sequential_engine_build = bool_as_int(options.force_sequential_engine_build),
                    trt_context_memory_sharing_enable = bool_as_int(options.enable_context_memory_sharing),
                    trt_layer_norm_fp32_fallback = bool_as_int(options.layer_norm_fp32_fallback),
                    trt_timing_cache_enable = bool_as_int(options.timing_cache_enable),
                    // NOTE(review): confirm this key name against the TensorRT provider's option list.
                    trt_force_timing_cache_match = bool_as_int(options.force_timing_cache),
                    trt_detailed_build_log = bool_as_int(options.detailed_build_log),
                    trt_build_heuristics_enable = bool_as_int(options.enable_build_heuristics),
                    trt_sparsity_enable = bool_as_int(options.enable_sparsity),
                    trt_builder_optimization_level = options.builder_optimization_level,
                    trt_auxiliary_streams = options.auxiliary_streams,
                    trt_tactic_sources = options.tactic_sources,
                    trt_extra_plugin_lib_paths = options.extra_plugin_lib_paths,
                    trt_profile_min_shapes = options.profile_min_shapes,
                    trt_profile_max_shapes = options.profile_max_shapes,
                    trt_profile_opt_shapes = options.profile_opt_shapes
                };
                if let Err(e) = status_to_result(ortsys![unsafe UpdateTensorRTProviderOptions(trt_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
                    .map_err(OrtError::ExecutionProvider)
                {
                    ortsys![unsafe ReleaseTensorRTProviderOptions(trt_options)];
                    std::mem::drop((keys, values));
                    return Err(e);
                }
                let status = ortsys![unsafe SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options)];
                ortsys![unsafe ReleaseTensorRTProviderOptions(trt_options)];
                std::mem::drop((keys, values));
                status_to_result(status).map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "acl"))]
            Self::ACL(options) => {
                get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr);
                status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_ACL(session_options, options.use_arena.into()) })
                    .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "onednn"))]
            Self::OneDNN(options) => {
                get_ep_register!(OrtSessionOptionsAppendExecutionProvider_Dnnl(options: *mut sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> sys::OrtStatusPtr);
                status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, options.use_arena.into()) })
                    .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "coreml"))]
            Self::CoreML(options) => {
                get_ep_register!(OrtSessionOptionsAppendExecutionProvider_CoreML(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr);
                // Each option contributes one bit to the flag word.
                let mut flags = 0;
                if options.use_cpu_only {
                    flags |= 0x001;
                }
                if options.enable_on_subgraph {
                    flags |= 0x002;
                }
                if options.only_enable_device_with_ane {
                    flags |= 0x004;
                }
                status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, flags) }).map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "directml"))]
            Self::DirectML(options) => {
                get_ep_register!(OrtSessionOptionsAppendExecutionProvider_DML(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr);
                status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_DML(session_options, options.device_id as _) })
                    .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "rocm"))]
            Self::ROCm(options) => {
                let rocm_options = sys::OrtROCMProviderOptions {
                    device_id: options.device_id,
                    miopen_conv_exhaustive_search: bool_as_int(options.miopen_conv_exhaustive_search),
                    gpu_mem_limit: options.gpu_mem_limit as _,
                    arena_extend_strategy: match options.arena_extend_strategy {
                        ArenaExtendStrategy::NextPowerOfTwo => 0,
                        ArenaExtendStrategy::SameAsRequested => 1
                    },
                    do_copy_in_default_stream: bool_as_int(options.do_copy_in_default_stream),
                    has_user_compute_stream: bool_as_int(options.user_compute_stream.is_some()),
                    user_compute_stream: options.user_compute_stream.unwrap_or(ptr::null_mut()),
                    default_memory_arena_cfg: options.default_memory_arena_cfg.unwrap_or(ptr::null_mut()),
                    tunable_op_enable: bool_as_int(options.tunable_op_enable),
                    tunable_op_tuning_enable: bool_as_int(options.tunable_op_tuning_enable),
                    tunable_op_max_tuning_duration_ms: options.tunable_op_max_tuning_duration_ms
                };
                status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_ROCM(session_options, &rocm_options as *const _)])
                    .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "nnapi"))]
            Self::NNAPI(options) => {
                get_ep_register!(OrtSessionOptionsAppendExecutionProvider_Nnapi(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr);
                // Each option contributes one bit to the flag word.
                let mut flags = 0;
                if options.use_fp16 {
                    flags |= 0x001;
                }
                if options.use_nchw {
                    flags |= 0x002;
                }
                if options.disable_cpu {
                    flags |= 0x004;
                }
                if options.cpu_only {
                    flags |= 0x008;
                }
                status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, flags) }).map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "openvino"))]
            Self::OpenVINO(options) => {
                // The C struct stores raw `char` pointers, so each string needs a NUL-terminated
                // copy that stays alive until the registration call below has returned. Pointing
                // into a temporary `String` would leave a dangling, unterminated pointer.
                let to_c_string = |value: &Option<String>| {
                    value
                        .as_deref()
                        .map(|s| CString::new(s).expect("OpenVINO string options must not contain NUL bytes"))
                };
                let device_type = to_c_string(&options.device_type);
                let device_id = to_c_string(&options.device_id);
                let cache_dir = to_c_string(&options.cache_dir);
                // Unset options are passed as null pointers.
                let ptr_or_null = |value: &Option<CString>| value.as_ref().map_or(ptr::null(), |s| s.as_ptr());

                let openvino_options = sys::OrtOpenVINOProviderOptions {
                    device_type: ptr_or_null(&device_type),
                    device_id: ptr_or_null(&device_id),
                    num_of_threads: options.num_threads,
                    cache_dir: ptr_or_null(&cache_dir),
                    context: options.context,
                    enable_opencl_throttling: bool_as_int(options.enable_opencl_throttling) as _,
                    enable_dynamic_shapes: bool_as_int(options.enable_dynamic_shapes) as _,
                    enable_vpu_fast_compile: bool_as_int(options.enable_vpu_fast_compile) as _
                };
                status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_options, &openvino_options as *const _)])
                    .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "qnn"))]
            Self::QNN(options) => {
                // `_keys`/`_values` own the C strings behind the pointer arrays and must outlive the call.
                let (key_ptrs, value_ptrs, len, _keys, _values) = map_keys! {
                    backend_path = options.backend_path,
                    profiling_level = options.profiling_level.as_deref().unwrap_or("off"),
                    qnn_context_cache_enable = bool_as_int(options.qnn_context_cache_enable),
                    qnn_context_cache_path = options.qnn_context_cache_path.as_deref().unwrap_or("model_file.onnx.bin"),
                    htp_performance_mode = options.htp_performance_mode.as_ref().unwrap_or(&QNNExecutionHTPPerformanceMode::Default).as_str(),
                    rpc_control_latency = options.rpc_control_latency.unwrap_or(10)
                };
                let name = CString::new("QNN").unwrap();
                status_to_result(ortsys![unsafe SessionOptionsAppendExecutionProvider(
                    session_options,
                    name.as_ptr(),
                    key_ptrs.as_ptr(),
                    value_ptrs.as_ptr(),
                    len as _,
                )])
                .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "tvm"))]
            Self::TVM(options) => {
                get_ep_register!(OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut sys::OrtSessionOptions, opt_str: *const std::os::raw::c_char) -> sys::OrtStatusPtr);
                // Options are sent as a single comma-separated list of `key:value` entries;
                // unset options are simply left out.
                let mut option_string = Vec::new();
                if let Some(check_hash) = options.check_hash {
                    option_string.push(format!("check_hash:{}", if check_hash { "True" } else { "False" }));
                }
                if let Some(executor) = options.executor {
                    option_string.push(format!(
                        "executor:{}",
                        match executor {
                            TVMExecutorType::GraphExecutor => "graph",
                            TVMExecutorType::VirtualMachine => "vm"
                        }
                    ));
                }
                if let Some(freeze_weights) = options.freeze_weights {
                    option_string.push(format!("freeze_weights:{}", if freeze_weights { "True" } else { "False" }));
                }
                if let Some(hash_file_path) = options.hash_file_path.as_ref() {
                    option_string.push(format!("hash_file_path:{hash_file_path}"));
                }
                if let Some(input_names) = options.input_names.as_ref() {
                    option_string.push(format!("input_names:{input_names}"));
                }
                if let Some(input_shapes) = options.input_shapes.as_ref() {
                    option_string.push(format!("input_shapes:{input_shapes}"));
                }
                if let Some(opt_level) = options.opt_level {
                    option_string.push(format!("opt_level:{opt_level}"));
                }
                if let Some(so_folder) = options.so_folder.as_ref() {
                    option_string.push(format!("so_folder:{so_folder}"));
                }
                if let Some(target) = options.target.as_ref() {
                    option_string.push(format!("target:{target}"));
                }
                if let Some(target_host) = options.target_host.as_ref() {
                    option_string.push(format!("target_host:{target_host}"));
                }
                if let Some(to_nhwc) = options.to_nhwc {
                    option_string.push(format!("to_nhwc:{}", if to_nhwc { "True" } else { "False" }));
                }
                if let Some(tuning_file_path) = options.tuning_file_path.as_ref() {
                    option_string.push(format!("tuning_file_path:{tuning_file_path}"));
                }
                if let Some(tuning_type) = options.tuning_type {
                    option_string.push(format!(
                        "tuning_type:{}",
                        match tuning_type {
                            TVMTuningType::AutoTVM => "AutoTVM",
                            TVMTuningType::Ansor => "Ansor"
                        }
                    ));
                }
                let options_string = CString::new(option_string.join(",")).unwrap();
                status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_Tvm(session_options, options_string.as_ptr()) })
                    .map_err(OrtError::ExecutionProvider)?;
            }
            #[cfg(any(feature = "load-dynamic", feature = "cann"))]
            Self::CANN(options) => {
                let mut cann_options: *mut sys::OrtCANNProviderOptions = ptr::null_mut();
                status_to_result(ortsys![unsafe CreateCANNProviderOptions(&mut cann_options)]).map_err(OrtError::ExecutionProvider)?;
                // `keys`/`values` own the C strings that `key_ptrs`/`value_ptrs` point into.
                let (key_ptrs, value_ptrs, len, keys, values) = map_keys! {
                    device_id = options.device_id,
                    npu_mem_limit = options.npu_mem_limit,
                    arena_extend_strategy = match options.arena_extend_strategy {
                        ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
                        ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
                    },
                    enable_cann_graph = bool_as_int(options.enable_cann_graph),
                    dump_graphs = bool_as_int(options.dump_graphs),
                    precision_mode = match options.precision_mode {
                        CANNExecutionProviderPrecisionMode::ForceFP32 => "force_fp32",
                        CANNExecutionProviderPrecisionMode::ForceFP16 => "force_fp16",
                        CANNExecutionProviderPrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
                        CANNExecutionProviderPrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
                        CANNExecutionProviderPrecisionMode::AllowMixedPrecision => "allow_mix_precision"
                    },
                    op_select_impl_mode = match options.op_select_impl_mode {
                        CANNExecutionProviderImplementationMode::HighPrecision => "high_precision",
                        CANNExecutionProviderImplementationMode::HighPerformance => "high_performance"
                    },
                    optypelist_for_impl_mode = options.optypelist_for_impl_mode
                };
                if let Err(e) = status_to_result(ortsys![unsafe UpdateCANNProviderOptions(cann_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
                    .map_err(OrtError::ExecutionProvider)
                {
                    ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
                    std::mem::drop((keys, values));
                    return Err(e);
                }
                let status = ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_options, cann_options)];
                ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
                std::mem::drop((keys, values));
                status_to_result(status).map_err(OrtError::ExecutionProvider)?;
            }
            // Reached for any provider whose arm was compiled out by the feature gates above.
            _ => return Err(OrtError::ExecutionProviderNotRegistered(self.as_str()))
        }
        Ok(())
    }
}
#[tracing::instrument(skip_all)]
pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, execution_providers: impl AsRef<[ExecutionProvider]>) {
for ex in execution_providers.as_ref() {
if let Err(e) = ex.apply(options) {
if let &OrtError::ExecutionProviderNotRegistered(_) = &e {
tracing::debug!("{}", e);
} else {
tracing::warn!("An error occurred when attempting to register `{}`: {e}", ex.as_str());
}
} else {
tracing::info!("Successfully registered `{}`", ex.as_str());
return;
}
}
tracing::warn!("No execution providers registered successfully. Falling back to CPU.");
}