use alloc::string::String;
use super::{ExecutionProvider, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
/// Runtime executor the TVM execution provider should use.
///
/// Serialized into the `executor` provider option when [`TVM`] is registered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutorType {
	/// TVM graph executor; sent as `executor:graph`.
	GraphExecutor,
	/// TVM virtual machine; sent as `executor:vm`.
	VirtualMachine,
}
/// Auto-tuning framework that produced the TVM tuning log.
///
/// NOTE(review): corresponds to ONNX Runtime's TVM `tuning_type` provider option — confirm
/// against the ONNX Runtime TVM EP documentation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TuningType {
	/// Tuning log produced by AutoTVM.
	AutoTVM,
	/// Tuning log produced by Ansor.
	Ansor,
}
/// Options for ONNX Runtime's TVM execution provider.
///
/// Every field is optional; a field left as `None` is not sent to ONNX Runtime.
#[derive(Clone, Debug, Default)]
pub struct TVM {
	/// Executor used to run the compiled model (`executor` option).
	pub executor: Option<ExecutorType>,
	/// Value of the `so_folder` option.
	/// NOTE(review): presumably a directory holding a precompiled model library — confirm.
	pub so_folder: Option<String>,
	/// Value of the `check_hash` option, sent as `True`/`False`.
	/// NOTE(review): presumably toggles hash verification of the precompiled model.
	pub check_hash: Option<bool>,
	/// Value of the `hash_file_path` option.
	/// NOTE(review): presumably the file containing the expected model hash.
	pub hash_file_path: Option<String>,
	/// TVM target string (`target` option), passed through verbatim.
	pub target: Option<String>,
	/// TVM host target string (`target_host` option), passed through verbatim.
	pub target_host: Option<String>,
	/// Optimization level (`opt_level` option).
	pub opt_level: Option<usize>,
	/// Value of the `freeze_weights` option, sent as `True`/`False`.
	pub freeze_weights: Option<bool>,
	/// Value of the `to_nhwc` option, sent as `True`/`False`.
	pub to_nhwc: Option<bool>,
	/// Framework that produced the tuning log.
	/// NOTE(review): corresponds to ONNX Runtime's `tuning_type` option — confirm it is forwarded.
	pub tuning_type: Option<TuningType>,
	/// Path of the tuning log.
	/// NOTE(review): corresponds to ONNX Runtime's `tuning_file_path` option — confirm it is forwarded.
	pub tuning_file_path: Option<String>,
	/// Value of the `input_names` option, passed through verbatim.
	pub input_names: Option<String>,
	/// Value of the `input_shapes` option, passed through verbatim.
	pub input_shapes: Option<String>,
}
// Invokes the parent module's `impl_ep!` macro for `TVM`.
// NOTE(review): presumably generates the shared execution-provider boilerplate (conversions/builder
// helpers) common to all EPs — confirm against the macro definition in the parent module.
super::impl_ep!(TVM);
impl ExecutionProvider for TVM {
	/// Returns the ONNX Runtime name of this execution provider.
	fn name(&self) -> &'static str {
		"TvmExecutionProvider"
	}

	/// Appends the TVM execution provider to `session_builder`.
	///
	/// Every option that is `Some` is serialized as `key:value` (booleans as `True`/`False`), and
	/// the pairs are joined with `,` into one C string handed to
	/// `OrtSessionOptionsAppendExecutionProvider_Tvm`.
	///
	/// # Errors
	/// Returns [`RegisterError::MissingFeature`] when neither the `load-dynamic` nor the `tvm`
	/// feature is enabled. Otherwise, any failure status returned by the C API is converted with
	/// `Error::result_from_status` and propagated.
	///
	/// # Panics
	/// Panics if any string option contains an interior NUL byte, because the serialized option
	/// string cannot then be turned into a `CString`.
	// `unused`: with both features disabled, `self` and `session_builder` are never read.
	// `unreachable_code`: with a feature enabled, the trailing `Err(MissingFeature)` is dead.
	#[allow(unused, unreachable_code)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		#[cfg(any(feature = "load-dynamic", feature = "tvm"))]
		{
			use alloc::{format, vec::Vec};

			use crate::AsPointer;

			super::define_ep_register!(OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const core::ffi::c_char) -> ort_sys::OrtStatusPtr);

			// Boolean options are spelled `True` / `False` in the option string.
			let flag = |value: bool| if value { "True" } else { "False" };

			// Keys are emitted in alphabetical order.
			let mut option_string: Vec<String> = Vec::new();
			if let Some(check_hash) = self.check_hash {
				option_string.push(format!("check_hash:{}", flag(check_hash)));
			}
			if let Some(executor) = self.executor {
				option_string.push(format!(
					"executor:{}",
					match executor {
						ExecutorType::GraphExecutor => "graph",
						ExecutorType::VirtualMachine => "vm"
					}
				));
			}
			if let Some(freeze_weights) = self.freeze_weights {
				option_string.push(format!("freeze_weights:{}", flag(freeze_weights)));
			}
			if let Some(hash_file_path) = self.hash_file_path.as_ref() {
				option_string.push(format!("hash_file_path:{hash_file_path}"));
			}
			if let Some(input_names) = self.input_names.as_ref() {
				option_string.push(format!("input_names:{input_names}"));
			}
			if let Some(input_shapes) = self.input_shapes.as_ref() {
				option_string.push(format!("input_shapes:{input_shapes}"));
			}
			if let Some(opt_level) = self.opt_level {
				option_string.push(format!("opt_level:{opt_level}"));
			}
			if let Some(so_folder) = self.so_folder.as_ref() {
				option_string.push(format!("so_folder:{so_folder}"));
			}
			if let Some(target) = self.target.as_ref() {
				option_string.push(format!("target:{target}"));
			}
			if let Some(target_host) = self.target_host.as_ref() {
				option_string.push(format!("target_host:{target_host}"));
			}
			if let Some(to_nhwc) = self.to_nhwc {
				option_string.push(format!("to_nhwc:{}", flag(to_nhwc)));
			}
			// The tuning fields are public configuration and must be forwarded; otherwise setting them
			// silently has no effect.
			if let Some(tuning_file_path) = self.tuning_file_path.as_ref() {
				option_string.push(format!("tuning_file_path:{tuning_file_path}"));
			}
			if let Some(tuning_type) = self.tuning_type {
				// Spelled to match ONNX Runtime's TVM tuning-type names.
				option_string.push(format!(
					"tuning_type:{}",
					match tuning_type {
						TuningType::AutoTVM => "AutoTVM",
						TuningType::Ansor => "Ansor"
					}
				));
			}

			let options_string = alloc::ffi::CString::new(option_string.join(",")).expect("invalid option string");
			// SAFETY: `session_builder` is mutably borrowed for the whole call, so the
			// `*mut OrtSessionOptions` from `ptr_mut()` is live and not aliased by other Rust code.
			// `options_string` is a NUL-terminated `CString` that lives until the end of this block,
			// so `as_ptr()` stays valid for the duration of the FFI call.
			return Ok(unsafe {
				crate::error::Error::result_from_status(OrtSessionOptionsAppendExecutionProvider_Tvm(session_builder.ptr_mut(), options_string.as_ptr()))
			}?);
		}

		Err(RegisterError::MissingFeature)
	}
}