ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 - Optimize and accelerate machine learning inference & training
Documentation
use alloc::ffi::CString;

use super::{ExecutionProvider, RegisterError};
use crate::{ep::ArenaExtendStrategy, error::Result, session::builder::SessionBuilder};

/// [MIGraphX execution provider](https://onnxruntime.ai/docs/execution-providers/MIGraphX-ExecutionProvider.html) for
/// hardware acceleration with AMD GPUs.
///
/// Construct with [`Default::default`] and configure through the `with_*` builder methods; every setting is
/// forwarded to ONNX Runtime as an `OrtMIGraphXProviderOptions` field when the provider is registered.
#[derive(Debug, Clone)]
pub struct MIGraphX {
	/// Device index to run on (`device_id`).
	device_id: i32,
	/// Whether FP16 mode is requested (`migraphx_fp16_enable`).
	enable_fp16: bool,
	/// Whether FP8 mode is requested (`migraphx_fp8_enable`).
	enable_fp8: bool,
	/// Whether int8 mode is requested (`migraphx_int8_enable`).
	enable_int8: bool,
	/// `true` if the calibration table is in native format, `false` for the JSON dump format
	/// (`migraphx_use_native_calibration_table`).
	use_native_calibration_table: bool,
	/// Optional int8 calibration table name/path. It is held as a `CString` so a NUL-terminated pointer can be
	/// handed to the C API (`migraphx_int8_calibration_table_name`).
	int8_calibration_table_name: Option<CString>,
	/// Where to save the compiled model; `Some` also turns on `migraphx_save_compiled_model`.
	save_model_path: Option<CString>,
	/// Where to load a previously compiled model from; `Some` also turns on `migraphx_load_compiled_model`.
	load_model_path: Option<CString>,
	/// Whether exhaustive tuning is requested (`migraphx_exhaustive_tune`).
	exhaustive_tune: bool,
	/// Memory limit in bytes (`migraphx_mem_limit`); `usize::MAX` by default.
	memory_limit: usize,
	/// Growth policy for the memory arena (`migraphx_arena_extend_strategy`, encoded as 0/1 at registration).
	arena_extend_strategy: ArenaExtendStrategy
}

impl Default for MIGraphX {
	fn default() -> Self {
		Self {
			device_id: 0,
			enable_fp16: false,
			enable_fp8: false,
			enable_int8: false,
			use_native_calibration_table: false,
			int8_calibration_table_name: None,
			save_model_path: None,
			load_model_path: None,
			exhaustive_tune: false,
			memory_limit: usize::MAX,
			arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo
		}
	}
}

// Shared execution-provider boilerplate generated by the parent module's macro. It presumably supplies common
// glue such as `build()`; confirm against `super::impl_ep`.
super::impl_ep!(MIGraphX);

impl MIGraphX {
	/// Converts a caller-supplied option string into the NUL-terminated form the C API expects.
	///
	/// # Panics
	///
	/// Panics if `value` contains an interior NUL byte, which cannot be represented in a C string.
	fn make_c_string(value: &str) -> CString {
		CString::new(value).expect("MIGraphX option strings must not contain interior NUL bytes")
	}

	/// Configures which device the EP should use. Defaults to `0`.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_device_id(0).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_device_id(mut self, device_id: i32) -> Self {
		self.device_id = device_id;
		self
	}

	/// Enable FP16 quantization for the model. Disabled by default.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_fp16(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_fp16(mut self, enable: bool) -> Self {
		self.enable_fp16 = enable;
		self
	}

	/// Enable FP8 quantization for the model. Disabled by default.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_fp8(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_fp8(mut self, enable: bool) -> Self {
		self.enable_fp8 = enable;
		self
	}

	/// Enable 8-bit integer quantization for the model. Requires
	/// [`MIGraphX::with_int8_calibration_table`] to be set. Disabled by default.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_int8(mut self, enable: bool) -> Self {
		self.enable_int8 = enable;
		self
	}

	/// Configures the path to the input calibration data for int8 quantization.
	///
	/// The `native` parameter specifies the format the calibration data is in: `true` for native int8 format,
	/// `false` for the JSON dump format.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
	/// # Ok(())
	/// # }
	/// ```
	///
	/// # Panics
	///
	/// Panics if `table_name` contains an interior NUL byte.
	#[must_use]
	pub fn with_int8_calibration_table(mut self, table_name: impl AsRef<str>, native: bool) -> Self {
		self.use_native_calibration_table = native;
		self.int8_calibration_table_name = Some(Self::make_c_string(table_name.as_ref()));
		self
	}

	/// Save the compiled MIGraphX model to the given path.
	///
	/// The compiled model can then be loaded in subsequent runs with [`MIGraphX::with_load_model`]. Setting a path
	/// also enables the provider's "save compiled model" flag.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_save_model("./compiled_model.mxr").build();
	/// # Ok(())
	/// # }
	/// ```
	///
	/// # Panics
	///
	/// Panics if `path` contains an interior NUL byte.
	#[must_use]
	pub fn with_save_model(mut self, path: impl AsRef<str>) -> Self {
		self.save_model_path = Some(Self::make_c_string(path.as_ref()));
		self
	}

	/// Load the compiled MIGraphX model (previously generated by [`MIGraphX::with_save_model`]) from
	/// the given path. Setting a path also enables the provider's "load compiled model" flag.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_load_model("./compiled_model.mxr").build();
	/// # Ok(())
	/// # }
	/// ```
	///
	/// # Panics
	///
	/// Panics if `path` contains an interior NUL byte.
	#[must_use]
	pub fn with_load_model(mut self, path: impl AsRef<str>) -> Self {
		self.load_model_path = Some(Self::make_c_string(path.as_ref()));
		self
	}

	/// Enable exhaustive tuning; trades loading time for inference performance. Disabled by default.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_exhaustive_tune(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_exhaustive_tune(mut self, enable: bool) -> Self {
		self.exhaustive_tune = enable;
		self
	}

	/// Sets the memory limit, in bytes, forwarded to MIGraphX as `migraphx_mem_limit`.
	///
	/// Defaults to `usize::MAX`.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_mem_limit(1 << 30).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_mem_limit(mut self, bytes: usize) -> Self {
		self.memory_limit = bytes;
		self
	}

	/// Configures how the memory arena grows when it needs more space. The value is forwarded to MIGraphX as
	/// `migraphx_arena_extend_strategy`.
	///
	/// Defaults to [`ArenaExtendStrategy::NextPowerOfTwo`].
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_arena_extend_strategy(ep::ArenaExtendStrategy::SameAsRequested).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
		self.arena_extend_strategy = strategy;
		self
	}
}

impl ExecutionProvider for MIGraphX {
	fn name(&self) -> &'static str {
		"MIGraphXExecutionProvider"
	}

	/// MIGraphX is only considered available on x86_64 Linux and x86_64 Windows.
	fn supported_by_platform(&self) -> bool {
		cfg!(all(target_arch = "x86_64", any(target_os = "linux", target_os = "windows")))
	}

	/// Appends the MIGraphX provider, with this configuration, to `session_builder`.
	///
	/// Returns [`RegisterError::MissingFeature`] when neither the `load-dynamic` nor the `migraphx` feature is
	/// enabled.
	// `unused`: with the features off, `self` and `session_builder` are never read.
	// `unreachable_code`: with a feature on, the trailing `Err` can never be reached.
	#[allow(unused, unreachable_code)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		#[cfg(any(feature = "load-dynamic", feature = "migraphx"))]
		{
			use core::ptr;

			use crate::{AsPointer, ortsys};

			// Optional strings become NULL when unset. The pointers borrow from `self`, which outlives this call.
			let calibration_table = self.int8_calibration_table_name.as_deref().map_or(ptr::null(), |c| c.as_ptr());
			let load_path = self.load_model_path.as_deref().map_or(ptr::null(), |c| c.as_ptr());
			let save_path = self.save_model_path.as_deref().map_or(ptr::null(), |c| c.as_ptr());

			// ONNX Runtime encodes the arena strategy as a plain integer.
			let arena_strategy = match self.arena_extend_strategy {
				ArenaExtendStrategy::NextPowerOfTwo => 0,
				ArenaExtendStrategy::SameAsRequested => 1
			};

			let options = ort_sys::OrtMIGraphXProviderOptions {
				device_id: self.device_id,
				migraphx_fp16_enable: self.enable_fp16.into(),
				migraphx_fp8_enable: self.enable_fp8.into(),
				migraphx_int8_enable: self.enable_int8.into(),
				migraphx_use_native_calibration_table: self.use_native_calibration_table.into(),
				migraphx_int8_calibration_table_name: calibration_table,
				migraphx_load_compiled_model: self.load_model_path.is_some().into(),
				migraphx_load_model_path: load_path,
				migraphx_save_compiled_model: self.save_model_path.is_some().into(),
				migraphx_save_model_path: save_path,
				migraphx_exhaustive_tune: self.exhaustive_tune,
				migraphx_mem_limit: self.memory_limit as _,
				migraphx_arena_extend_strategy: arena_strategy
			};
			ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.ptr_mut(), &options)?];
			return Ok(());
		}

		Err(RegisterError::MissingFeature)
	}
}