use alloc::ffi::CString;
use super::{ExecutionProvider, RegisterError};
use crate::{ep::ArenaExtendStrategy, error::Result, session::builder::SessionBuilder};
/// Builder-style configuration for ONNX Runtime's `MIGraphXExecutionProvider`.
///
/// Construct with [`Default::default`] and refine with the `with_*` methods; the settings are
/// translated into `OrtMIGraphXProviderOptions` when the provider is registered.
#[derive(Clone, Debug)]
pub struct MIGraphX {
	/// Device ordinal, forwarded as `device_id`.
	device_id: i32,
	/// Forwarded as `migraphx_fp16_enable`.
	enable_fp16: bool,
	/// Forwarded as `migraphx_fp8_enable`.
	enable_fp8: bool,
	/// Forwarded as `migraphx_int8_enable`.
	enable_int8: bool,
	/// Forwarded as `migraphx_use_native_calibration_table`.
	use_native_calibration_table: bool,
	/// INT8 calibration table name; `None` is passed to ORT as a null pointer.
	int8_calibration_table_name: Option<CString>,
	/// Where to save the compiled model; when set, saving is also switched on.
	save_model_path: Option<CString>,
	/// Where to load a compiled model from; when set, loading is also switched on.
	load_model_path: Option<CString>,
	/// Forwarded as `migraphx_exhaustive_tune`.
	exhaustive_tune: bool,
	/// Memory limit in bytes, forwarded as `migraphx_mem_limit` (defaults to `usize::MAX`).
	memory_limit: usize,
	/// How the memory arena grows, forwarded as `migraphx_arena_extend_strategy`.
	arena_extend_strategy: ArenaExtendStrategy,
}
impl Default for MIGraphX {
	/// Returns a configuration targeting device 0 with every optional mode switched off,
	/// no calibration table or model paths, a memory limit of `usize::MAX`, and the
	/// `NextPowerOfTwo` arena growth strategy.
	fn default() -> Self {
		Self {
			// Device and memory settings.
			device_id: 0,
			memory_limit: usize::MAX,
			arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo,
			// Precision / tuning switches: all disabled.
			enable_fp16: false,
			enable_fp8: false,
			enable_int8: false,
			exhaustive_tune: false,
			// INT8 calibration: no table configured.
			use_native_calibration_table: false,
			int8_calibration_table_name: None,
			// Compiled-model persistence: neither saving nor loading.
			save_model_path: None,
			load_model_path: None,
		}
	}
}
// Invokes the parent module's `impl_ep!` macro for this provider type.
// NOTE(review): presumably generates boilerplate shared by all execution providers
// (e.g. conversions used when registering EPs) — confirm against the macro definition.
super::impl_ep!(MIGraphX);
impl MIGraphX {
	/// Converts a user-supplied string into an owned, NUL-terminated [`CString`] so it can be
	/// handed to the ONNX Runtime C API.
	///
	/// # Panics
	///
	/// Panics if `value` contains an interior NUL byte, because such a string cannot be
	/// represented as a C string.
	fn make_c_string(value: &str) -> CString {
		CString::new(value).expect("invalid string: must not contain interior NUL bytes")
	}

	/// Selects the device to run on; forwarded to ONNX Runtime as `device_id`.
	#[must_use]
	pub fn with_device_id(mut self, device_id: i32) -> Self {
		self.device_id = device_id;
		self
	}

	/// Enables or disables MIGraphX FP16 mode (`migraphx_fp16_enable`).
	#[must_use]
	pub fn with_fp16(mut self, enable: bool) -> Self {
		self.enable_fp16 = enable;
		self
	}

	/// Enables or disables MIGraphX FP8 mode (`migraphx_fp8_enable`).
	#[must_use]
	pub fn with_fp8(mut self, enable: bool) -> Self {
		self.enable_fp8 = enable;
		self
	}

	/// Enables or disables MIGraphX INT8 mode (`migraphx_int8_enable`).
	#[must_use]
	pub fn with_int8(mut self, enable: bool) -> Self {
		self.enable_int8 = enable;
		self
	}

	/// Sets the INT8 calibration table name and whether it is in the native MIGraphX format
	/// (`migraphx_int8_calibration_table_name` / `migraphx_use_native_calibration_table`).
	///
	/// # Panics
	///
	/// Panics if `table_name` contains an interior NUL byte.
	#[must_use]
	pub fn with_int8_calibration_table(mut self, table_name: impl AsRef<str>, native: bool) -> Self {
		self.use_native_calibration_table = native;
		self.int8_calibration_table_name = Some(Self::make_c_string(table_name.as_ref()));
		self
	}

	/// Saves the compiled model to `path`; setting a path also enables
	/// `migraphx_save_compiled_model` at registration time.
	///
	/// # Panics
	///
	/// Panics if `path` contains an interior NUL byte.
	#[must_use]
	pub fn with_save_model(mut self, path: impl AsRef<str>) -> Self {
		self.save_model_path = Some(Self::make_c_string(path.as_ref()));
		self
	}

	/// Loads a previously compiled model from `path`; setting a path also enables
	/// `migraphx_load_compiled_model` at registration time.
	///
	/// # Panics
	///
	/// Panics if `path` contains an interior NUL byte.
	#[must_use]
	pub fn with_load_model(mut self, path: impl AsRef<str>) -> Self {
		self.load_model_path = Some(Self::make_c_string(path.as_ref()));
		self
	}

	/// Enables or disables exhaustive tuning during compilation (`migraphx_exhaustive_tune`).
	#[must_use]
	pub fn with_exhaustive_tune(mut self, enable: bool) -> Self {
		self.exhaustive_tune = enable;
		self
	}

	/// Sets the memory limit in bytes (`migraphx_mem_limit`); the default is `usize::MAX`.
	#[must_use]
	pub fn with_mem_limit(mut self, bytes: usize) -> Self {
		self.memory_limit = bytes;
		self
	}

	/// Sets how the memory arena grows (`migraphx_arena_extend_strategy`).
	#[must_use]
	pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
		self.arena_extend_strategy = strategy;
		self
	}
}
impl ExecutionProvider for MIGraphX {
	fn name(&self) -> &'static str {
		"MIGraphXExecutionProvider"
	}

	/// Reports support only for x86_64 targets running Linux or Windows.
	fn supported_by_platform(&self) -> bool {
		cfg!(all(target_arch = "x86_64", any(target_os = "linux", target_os = "windows")))
	}

	// `unused`: without the `load-dynamic`/`migraphx` features, `session_builder` is never read.
	// `unreachable_code`: with either feature enabled, the cfg'd block always returns, so the
	// trailing `Err` is dead code.
	#[allow(unused, unreachable_code)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		#[cfg(any(feature = "load-dynamic", feature = "migraphx"))]
		{
			use core::ptr;

			use crate::{AsPointer, ortsys};

			// Borrow an optional owned C string as a raw pointer, or null when it is unset.
			let raw_or_null = |value: &Option<CString>| value.as_ref().map_or(ptr::null(), |s| s.as_ptr());

			let options = ort_sys::OrtMIGraphXProviderOptions {
				device_id: self.device_id,
				migraphx_fp16_enable: self.enable_fp16.into(),
				migraphx_fp8_enable: self.enable_fp8.into(),
				migraphx_int8_enable: self.enable_int8.into(),
				migraphx_use_native_calibration_table: self.use_native_calibration_table.into(),
				migraphx_int8_calibration_table_name: raw_or_null(&self.int8_calibration_table_name),
				// Loading / saving a compiled model is switched on exactly when a path was configured.
				migraphx_load_compiled_model: self.load_model_path.is_some().into(),
				migraphx_load_model_path: raw_or_null(&self.load_model_path),
				migraphx_save_compiled_model: self.save_model_path.is_some().into(),
				migraphx_save_model_path: raw_or_null(&self.save_model_path),
				migraphx_exhaustive_tune: self.exhaustive_tune,
				migraphx_mem_limit: self.memory_limit as _,
				migraphx_arena_extend_strategy: match self.arena_extend_strategy {
					ArenaExtendStrategy::NextPowerOfTwo => 0,
					ArenaExtendStrategy::SameAsRequested => 1
				}
			};

			// SAFETY: `options` is a fully initialized local, and every string pointer in it borrows a
			// `CString` owned by `self`, so all pointers remain valid for the duration of this call.
			// NOTE(review): assumes ORT copies the option values rather than retaining the pointers — confirm.
			ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.ptr_mut(), &options)?];
			return Ok(());
		}

		Err(RegisterError::MissingFeature)
	}
}